From e4539ef2328c2724bd33db6e26fdecd1d1a28164 Mon Sep 17 00:00:00 2001 From: konrad Date: Sat, 1 Aug 2020 16:54:38 +0000 Subject: [PATCH] Use db sessions for task-related things (#621) Use db sessions for task-related things Co-authored-by: kolaente Reviewed-on: https://kolaente.dev/vikunja/api/pulls/621 --- pkg/models/bulk_task.go | 5 ++- pkg/models/kanban.go | 27 ++++++++----- pkg/models/kanban_rights.go | 2 +- pkg/models/list.go | 26 +++++++++++-- pkg/models/list_duplicate.go | 4 +- pkg/models/task_assignees.go | 17 +++++++-- pkg/models/tasks.go | 73 ++++++++++++++++++++++++------------ 7 files changed, 109 insertions(+), 45 deletions(-) diff --git a/pkg/models/bulk_task.go b/pkg/models/bulk_task.go index ff7b0a7830..d5541ff907 100644 --- a/pkg/models/bulk_task.go +++ b/pkg/models/bulk_task.go @@ -93,7 +93,7 @@ func (bt *BulkTask) Update() (err error) { updateDone(oldtask, &bt.Task) // Update the assignees - if err := oldtask.updateTaskAssignees(bt.Assignees); err != nil { + if err := oldtask.updateTaskAssignees(sess, bt.Assignees); err != nil { return err } @@ -121,7 +121,8 @@ func (bt *BulkTask) Update() (err error) { "end_date"). Update(oldtask) if err != nil { - return sess.Rollback() + _ = sess.Rollback() + return err } } diff --git a/pkg/models/kanban.go b/pkg/models/kanban.go index aaaecda303..6dac185ffd 100644 --- a/pkg/models/kanban.go +++ b/pkg/models/kanban.go @@ -21,6 +21,7 @@ import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" "time" + "xorm.io/xorm" ) // Bucket represents a kanban bucket @@ -52,9 +53,9 @@ func (b *Bucket) TableName() string { return "buckets" } -func getBucketByID(id int64) (b *Bucket, err error) { +func getBucketByID(s *xorm.Session, id int64) (b *Bucket, err error) { b = &Bucket{} - exists, err := x.Where("id = ?", id).Get(b) + exists, err := s.Where("id = ?", id).Get(b) if err != nil { return } @@ -64,9 +65,9 @@ func getBucketByID(id int64) (b *Bucket, err error) { return } -func getDefaultBucket(listID int64) (bucket *Bucket, err error) { +func getDefaultBucket(s *xorm.Session, listID int64) (bucket *Bucket, err error) { bucket = &Bucket{} - _, err = x. + _, err = s. Where("list_id = ?", listID). OrderBy("id asc"). Get(bucket) @@ -199,9 +200,13 @@ func (b *Bucket) Update() (err error) { // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/buckets/{bucketID} [delete] func (b *Bucket) Delete() (err error) { + + s := x.NewSession() + // Prevent removing the last bucket - total, err := x.Where("list_id = ?", b.ListID).Count(&Bucket{}) + total, err := s.Where("list_id = ?", b.ListID).Count(&Bucket{}) if err != nil { + _ = s.Rollback() return } if total <= 1 { @@ -212,21 +217,25 @@ func (b *Bucket) Delete() (err error) { } // Remove the bucket itself - _, err = x.Where("id = ?", b.ID).Delete(&Bucket{}) + _, err = s.Where("id = ?", b.ID).Delete(&Bucket{}) if err != nil { + _ = s.Rollback() return } // Get the default bucket - defaultBucket, err := getDefaultBucket(b.ListID) + defaultBucket, err := getDefaultBucket(s, b.ListID) if err != nil { + _ = s.Rollback() return } // Remove all associations of tasks to that bucket - _, err = x.Where("bucket_id = ?", b.ID).Cols("bucket_id").Update(&Task{BucketID: defaultBucket.ID}) + _, err = s.Where("bucket_id = ?", b.ID).Cols("bucket_id").Update(&Task{BucketID: defaultBucket.ID}) if err != nil { + _ = s.Rollback() return } - return + + return s.Commit() } diff --git a/pkg/models/kanban_rights.go b/pkg/models/kanban_rights.go index cdf386812f..83881f1e6e 100644 --- a/pkg/models/kanban_rights.go +++ b/pkg/models/kanban_rights.go @@ -36,7 +36,7 @@ func (b *Bucket) CanDelete(a web.Auth) (bool, error) { // canDoBucket checks if the bucket exists and if the user has the right to act on it func (b *Bucket) canDoBucket(a web.Auth) (bool, error) { - bb, err := getBucketByID(b.ID) + bb, err := getBucketByID(x.NewSession(), b.ID) if err != nil { return false, err } diff --git a/pkg/models/list.go b/pkg/models/list.go index 6d44889a4e..c06c9c0975 100644 --- a/pkg/models/list.go +++ b/pkg/models/list.go @@ -199,6 +199,16 @@ func (l *List) ReadOne() (err error) { // GetSimpleByID gets a list with only the basic items, aka no tasks or user objects. Returns an error if the list does not exist. func (l *List) GetSimpleByID() (err error) { + s := x.NewSession() + err = l.getSimpleByID(s) + if err != nil { + _ = s.Rollback() + return err + } + return nil +} + +func (l *List) getSimpleByID(s *xorm.Session) (err error) { if l.ID < 1 { return ErrListDoesNotExist{ID: l.ID} } @@ -207,7 +217,7 @@ func (l *List) GetSimpleByID() (err error) { // leading to not finding anything if the id is good, but for example the title is different. id := l.ID *l = List{} - exists, err := x.Where("id = ?", id).Get(l) + exists, err := s.Where("id = ?", id).Get(l) if err != nil { return } @@ -520,8 +530,18 @@ func (l *List) Update() (err error) { return CreateOrUpdateList(l) } -func updateListLastUpdated(list *List) error { - _, err := x.ID(list.ID).Cols("updated").Update(list) +func updateListLastUpdated(list *List) (err error) { + s := x.NewSession() + err = updateListLastUpdatedS(s, list) + if err != nil { + _ = s.Rollback() + return err + } + return nil +} + +func updateListLastUpdatedS(s *xorm.Session, list *List) error { + _, err := s.ID(list.ID).Cols("updated").Update(list) return err } diff --git a/pkg/models/list_duplicate.go b/pkg/models/list_duplicate.go index fbb2176029..7a70fd9129 100644 --- a/pkg/models/list_duplicate.go +++ b/pkg/models/list_duplicate.go @@ -120,8 +120,10 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { t.ListID = ld.List.ID t.BucketID = bucketMap[t.BucketID] t.UID = "" - err := createTask(t, a, false) + s := x.NewSession() + err := createTask(s, t, a, false) if err != nil { + _ = s.Rollback() return err } taskMap[oldID] = t.ID diff --git a/pkg/models/task_assignees.go b/pkg/models/task_assignees.go index 4853a336f7..f762d7b6d2 100644 --- a/pkg/models/task_assignees.go +++ b/pkg/models/task_assignees.go @@ -20,6 +20,7 @@ import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" "time" + "xorm.io/xorm" ) // TaskAssginee represents an assignment of a user to a task @@ -55,7 +56,7 @@ func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*TaskAssignee } // Create or update a bunch of task assignees -func (t *Task) updateTaskAssignees(assignees []*user.User) (err error) { +func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err error) { // Load the current assignees currentAssignees, err := getRawTaskAssigneesForTasks([]int64{t.ID}) @@ -70,7 +71,7 @@ func (t *Task) updateTaskAssignees(assignees []*user.User) (err error) { // If we don't have any new assignees, delete everything right away. Saves us some hassle. if len(assignees) == 0 && len(t.Assignees) > 0 { - _, err = x.Where("task_id = ?", t.ID). + _, err = s.Where("task_id = ?", t.ID). Delete(TaskAssginee{}) t.setTaskAssignees(assignees) return err @@ -107,7 +108,7 @@ func (t *Task) updateTaskAssignees(assignees []*user.User) (err error) { // Delete all assignees not passed if len(assigneesToDelete) > 0 { - _, err = x.In("user_id", assigneesToDelete). + _, err = s.In("user_id", assigneesToDelete). And("task_id = ?", t.ID). Delete(TaskAssginee{}) if err != nil { @@ -300,6 +301,8 @@ type BulkAssignees struct { // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/assignees/bulk [post] func (ba *BulkAssignees) Create(a web.Auth) (err error) { + s := x.NewSession() + task, err := GetTaskByIDSimple(ba.TaskID) if err != nil { return @@ -312,5 +315,11 @@ func (ba *BulkAssignees) Create(a web.Auth) (err error) { task.Assignees = append(task.Assignees, &a.User) } - return task.updateTaskAssignees(ba.Assignees) + err = task.updateTaskAssignees(s, ba.Assignees) + if err != nil { + _ = s.Rollback() + return err + } + + return s.Commit() } diff --git a/pkg/models/tasks.go b/pkg/models/tasks.go index 4808105600..5b59472f85 100644 --- a/pkg/models/tasks.go +++ b/pkg/models/tasks.go @@ -28,6 +28,7 @@ import ( "strconv" "time" "xorm.io/builder" + "xorm.io/xorm" "xorm.io/xorm/schemas" ) @@ -525,9 +526,9 @@ func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) { return } -func checkBucketAndTaskBelongToSameList(fullTask *Task, bucketID int64) (err error) { +func checkBucketAndTaskBelongToSameList(s *xorm.Session, fullTask *Task, bucketID int64) (err error) { if bucketID != 0 { - b, err := getBucketByID(bucketID) + b, err := getBucketByID(s, bucketID) if err != nil { return err } @@ -556,10 +557,16 @@ func checkBucketAndTaskBelongToSameList(fullTask *Task, bucketID int64) (err err // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id} [put] func (t *Task) Create(a web.Auth) (err error) { - return createTask(t, a, true) + s := x.NewSession() + err = createTask(s, t, a, true) + if err != nil { + _ = s.Rollback() + return err + } + return s.Commit() } -func createTask(t *Task, a web.Auth, updateAssignees bool) (err error) { +func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err error) { t.ID = 0 @@ -570,7 +577,7 @@ func createTask(t *Task, a web.Auth, updateAssignees bool) (err error) { // Check if the list exists l := &List{ID: t.ListID} - if err = l.GetSimpleByID(); err != nil { + if err = l.getSimpleByID(s); err != nil { return } @@ -592,14 +599,14 @@ func createTask(t *Task, a web.Auth, updateAssignees bool) (err error) { } // If there is a bucket set, make sure they belong to the same list as the task - err = checkBucketAndTaskBelongToSameList(t, t.BucketID) + err = checkBucketAndTaskBelongToSameList(s, t, t.BucketID) if err != nil { return } // Get the default bucket and move the task there if t.BucketID == 0 { - defaultBucket, err := getDefaultBucket(t.ListID) + defaultBucket, err := getDefaultBucket(s, t.ListID) if err != nil { return err } @@ -608,7 +615,7 @@ func createTask(t *Task, a web.Auth, updateAssignees bool) (err error) { // Get the index for this task latestTask := &Task{} - _, err = x.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask) + _, err = s.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask) if err != nil { return err } @@ -618,19 +625,19 @@ func createTask(t *Task, a web.Auth, updateAssignees bool) (err error) { if t.Position == 0 { t.Position = float64(latestTask.ID+1) * math.Pow(2, 16) } - if _, err = x.Insert(t); err != nil { + if _, err = s.Insert(t); err != nil { return err } // Update the assignees if updateAssignees { - if err := t.updateTaskAssignees(t.Assignees); err != nil { + if err := t.updateTaskAssignees(s, t.Assignees); err != nil { return err } } // Update the reminders - if err := t.updateReminders(t.Reminders); err != nil { + if err := t.updateReminders(s, t.Reminders); err != nil { return err } @@ -638,7 +645,7 @@ func createTask(t *Task, a web.Auth, updateAssignees bool) (err error) { t.setIdentifier(l) - err = updateListLastUpdated(&List{ID: t.ListID}) + err = updateListLastUpdatedS(s, &List{ID: t.ListID}) return } @@ -657,15 +664,20 @@ func createTask(t *Task, a web.Auth, updateAssignees bool) (err error) { // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{id} [post] func (t *Task) Update() (err error) { + + s := x.NewSession() + // Check if the task exists and get the old values ot, err := GetTaskByIDSimple(t.ID) if err != nil { + _ = s.Rollback() return } // Get the reminders reminders, err := getRemindersForTasks([]int64{t.ID}) if err != nil { + _ = s.Rollback() return } @@ -678,18 +690,21 @@ func (t *Task) Update() (err error) { updateDone(&ot, t) // Update the assignees - if err := ot.updateTaskAssignees(t.Assignees); err != nil { + if err := ot.updateTaskAssignees(s, t.Assignees); err != nil { + _ = s.Rollback() return err } // Update the reminders - if err := ot.updateReminders(t.Reminders); err != nil { + if err := ot.updateReminders(s, t.Reminders); err != nil { + _ = s.Rollback() return err } // If there is a bucket set, make sure they belong to the same list as the task - err = checkBucketAndTaskBelongToSameList(&ot, t.BucketID) + err = checkBucketAndTaskBelongToSameList(s, &ot, t.BucketID) if err != nil { + _ = s.Rollback() return } @@ -714,15 +729,17 @@ func (t *Task) Update() (err error) { // If the task is being moved between lists, make sure to move the bucket + index as well if t.ListID != 0 && ot.ListID != t.ListID { - b, err := getDefaultBucket(t.ListID) + b, err := getDefaultBucket(s, t.ListID) if err != nil { + _ = s.Rollback() return err } t.BucketID = b.ID latestTask := &Task{} - _, err = x.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask) + _, err = s.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask) if err != nil { + _ = s.Rollback() return err } @@ -751,6 +768,7 @@ func (t *Task) Update() (err error) { // Which is why we merge the actual task struct with the one we got from the db // The user struct overrides values in the actual one. if err := mergo.Merge(&ot, t, mergo.WithOverride); err != nil { + _ = s.Rollback() return err } @@ -803,16 +821,21 @@ func (t *Task) Update() (err error) { ot.RepeatFromCurrentDate = false } - _, err = x.ID(t.ID). + _, err = s.ID(t.ID). Cols(colsToUpdate...). Update(ot) *t = ot if err != nil { + _ = s.Rollback() return err } - err = updateListLastUpdated(&List{ID: t.ListID}) - return + err = updateListLastUpdatedS(s, &List{ID: t.ListID}) + if err != nil { + _ = s.Rollback() + return err + } + return s.Commit() } // This helper function updates the reminders, doneAt, start and end dates of the *old* task @@ -910,7 +933,7 @@ func updateDone(oldTask *Task, newTask *Task) { // Creates or deletes all necessary reminders without unneded db operations. // The parameter is a slice with unix dates which holds the new reminders. -func (t *Task) updateReminders(reminders []time.Time) (err error) { +func (t *Task) updateReminders(s *xorm.Session, reminders []time.Time) (err error) { // Load the current reminders taskReminders, err := getRemindersForTasks([]int64{t.ID}) @@ -925,7 +948,7 @@ func (t *Task) updateReminders(reminders []time.Time) (err error) { // If we're removing everything, delete all reminders right away if len(reminders) == 0 && len(t.Reminders) > 0 { - _, err = x.Where("task_id = ?", t.ID). + _, err = s.Where("task_id = ?", t.ID). Delete(TaskReminder{}) t.Reminders = nil return err @@ -963,7 +986,7 @@ func (t *Task) updateReminders(reminders []time.Time) (err error) { // Delete all reminders not passed if len(remindersToDelete) > 0 { - _, err = x.In("reminder", remindersToDelete). + _, err = s.In("reminder", remindersToDelete). And("task_id = ?", t.ID). Delete(TaskReminder{}) if err != nil { @@ -980,7 +1003,7 @@ func (t *Task) updateReminders(reminders []time.Time) (err error) { } // Add the new reminder - _, err = x.Insert(TaskReminder{TaskID: t.ID, Reminder: r}) + _, err = s.Insert(TaskReminder{TaskID: t.ID, Reminder: r}) if err != nil { return err } @@ -991,7 +1014,7 @@ func (t *Task) updateReminders(reminders []time.Time) (err error) { t.Reminders = nil } - err = updateListLastUpdated(&List{ID: t.ListID}) + err = updateListLastUpdatedS(s, &List{ID: t.ListID}) return }