Add "in" filter type

This commit is contained in:
kolaente 2020-12-19 17:24:40 +01:00
parent 92bcce3f7c
commit 41cd3339c0
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
3 changed files with 60 additions and 17 deletions

View File

@ -171,6 +171,7 @@
namespace_id: 14
updated: 2018-12-02 15:13:12
created: 2018-12-01 15:13:12
# User 1 does not have access to this list
-
id: 20
title: Test20

View File

@ -21,6 +21,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
"code.vikunja.io/api/pkg/config"
@ -40,6 +41,7 @@ const (
taskFilterComparatorLessEquals taskFilterComparator = "<="
taskFilterComparatorNotEquals taskFilterComparator = "!="
taskFilterComparatorLike taskFilterComparator = "like"
taskFilterComparatorIn taskFilterComparator = "in"
)
type taskFilter struct {
@ -89,7 +91,7 @@ func getTaskFiltersByCollections(c *TaskCollection) (filters []*taskFilter, err
// Cast the field value to its native type
if len(c.FilterValue) > i {
filter.value, err = getNativeValueForTaskField(filter.field, c.FilterValue[i])
filter.value, err = getNativeValueForTaskField(filter.field, filter.comparator, c.FilterValue[i])
if err != nil {
return nil, ErrInvalidTaskFilterValue{
Value: filter.field,
@ -113,7 +115,8 @@ func validateTaskFieldComparator(comparator taskFilterComparator) error {
taskFilterComparatorLess,
taskFilterComparatorLessEquals,
taskFilterComparatorNotEquals,
taskFilterComparatorLike:
taskFilterComparatorLike,
taskFilterComparatorIn:
return nil
case taskFilterComparatorInvalid:
fallthrough
@ -138,42 +141,67 @@ func getFilterComparatorFromString(comparator string) (taskFilterComparator, err
return taskFilterComparatorNotEquals, nil
case "like":
return taskFilterComparatorLike, nil
case "in":
return taskFilterComparatorIn, nil
default:
return taskFilterComparatorInvalid, ErrInvalidTaskFilterComparator{Comparator: taskFilterComparator(comparator)}
}
}
func getNativeValueForTaskField(fieldName, value string) (nativeValue interface{}, err error) {
field, ok := reflect.TypeOf(&Task{}).Elem().FieldByName(strcase.ToCamel(fieldName))
if !ok {
return nil, ErrInvalidTaskField{TaskField: fieldName}
}
func getValueForField(field reflect.StructField, rawValue string) (value interface{}, err error) {
switch field.Type.Kind() {
case reflect.Int64:
nativeValue, err = strconv.ParseInt(value, 10, 64)
value, err = strconv.ParseInt(rawValue, 10, 64)
case reflect.Float64:
nativeValue, err = strconv.ParseFloat(value, 64)
value, err = strconv.ParseFloat(rawValue, 64)
case reflect.String:
nativeValue = value
// value is already a string
case reflect.Bool:
nativeValue, err = strconv.ParseBool(value)
value, err = strconv.ParseBool(rawValue)
case reflect.Struct:
if field.Type == schemas.TimeType {
nativeValue, err = time.Parse(time.RFC3339, value)
nativeValue = nativeValue.(time.Time).In(config.GetTimeZone())
value, err = time.Parse(time.RFC3339, rawValue)
value = value.(time.Time).In(config.GetTimeZone())
}
case reflect.Slice:
t := reflect.SliceOf(schemas.TimeType)
if t != nil {
nativeValue, err = time.Parse(time.RFC3339, value)
nativeValue = nativeValue.(time.Time).In(config.GetTimeZone())
value, err = time.Parse(time.RFC3339, rawValue)
value = value.(time.Time).In(config.GetTimeZone())
return
}
fallthrough
default:
panic(fmt.Errorf("unrecognized filter type %s for field %s, value %s", field.Type.String(), fieldName, value))
panic(fmt.Errorf("unrecognized filter type %s for field %s, value %s", field.Type.String(), field.Name, value))
}
return
}
func getNativeValueForTaskField(fieldName string, comparator taskFilterComparator, value string) (nativeValue interface{}, err error) {
var realFieldName = strcase.ToCamel(fieldName)
if strings.ToLower(fieldName) == "id" {
realFieldName = "ID"
}
field, ok := reflect.TypeOf(&Task{}).Elem().FieldByName(realFieldName)
if !ok {
return nil, ErrInvalidTaskField{TaskField: fieldName}
}
if comparator == taskFilterComparatorIn {
vals := strings.Split(value, ",")
valueSlice := []interface{}{}
for _, val := range vals {
v, err := getValueForField(field, val)
if err != nil {
return nil, err
}
valueSlice = append(valueSlice, v)
}
return valueSlice, nil
}
return getValueForField(field, value)
}

View File

@ -892,6 +892,20 @@ func TestTaskCollection_ReadAll(t *testing.T) {
},
wantErr: false,
},
{
name: "filter in",
fields: fields{
FilterBy: []string{"id"},
FilterValue: []string{"1,2,34"}, // Task 34 is forbidden for user 1
FilterComparator: []string{"in"},
},
args: defaultArgs,
want: []*Task{
task1,
task2,
},
wantErr: false,
},
}
for _, tt := range tests {