Add session handling for list users

This commit is contained in:
kolaente 2020-12-22 20:46:42 +01:00
parent ada332e87c
commit a62d222132
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
4 changed files with 53 additions and 29 deletions

View File

@ -18,6 +18,7 @@ package models
import (
"time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
@ -71,7 +72,7 @@ type UserWithRight struct {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/users [put]
func (lu *ListUser) Create(a web.Auth) (err error) {
func (lu *ListUser) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the right is valid
if err := lu.Right.isValid(); err != nil {
@ -85,11 +86,11 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
}
// Check if the user exists
user, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(lu.Username)
if err != nil {
return err
}
lu.UserID = user.ID
lu.UserID = u.ID
// Check if the user already has access or is owner of that list
// We explicitly DONT check for teams here
@ -97,7 +98,7 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
return ErrUserAlreadyHasAccess{UserID: lu.UserID, ListID: lu.ListID}
}
exist, err := x.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{})
exist, err := s.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{})
if err != nil {
return
}
@ -106,7 +107,7 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
}
// Insert user <-> list relation
_, err = x.Insert(lu)
_, err = s.Insert(lu)
if err != nil {
return err
}
@ -128,17 +129,18 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "user or list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/users/{userID} [delete]
func (lu *ListUser) Delete() (err error) {
func (lu *ListUser) Delete(s *xorm.Session) (err error) {
// Check if the user exists
user, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(lu.Username)
if err != nil {
return
}
lu.UserID = user.ID
lu.UserID = u.ID
// Check if the user has access to the list
has, err := x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
has, err := s.
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
Get(&ListUser{})
if err != nil {
return
@ -147,7 +149,8 @@ func (lu *ListUser) Delete() (err error) {
return ErrUserDoesNotHaveAccessToList{ListID: lu.ListID, UserID: lu.UserID}
}
_, err = x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
_, err = s.
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
Delete(&ListUser{})
if err != nil {
return err
@ -172,7 +175,7 @@ func (lu *ListUser) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the list."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/users [get]
func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (lu *ListUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has access to the list
l := &List{ID: lu.ListID}
canRead, _, err := l.CanRead(s, a)
@ -187,7 +190,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all users
all := []*UserWithRight{}
query := x.
query := s.
Join("INNER", "users_list", "user_id = users.id").
Where("users_list.list_id = ?", lu.ListID).
Where("users.username LIKE ?", "%"+search+"%")
@ -204,7 +207,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
u.Email = ""
}
numberOfTotalItems, err = x.
numberOfTotalItems, err = s.
Join("INNER", "users_list", "user_id = users.id").
Where("users_list.list_id = ?", lu.ListID).
Where("users.username LIKE ?", "%"+search+"%").
@ -228,7 +231,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
// @Failure 404 {object} web.HTTPError "User or list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/users/{userID} [post]
func (lu *ListUser) Update() (err error) {
func (lu *ListUser) Update(s *xorm.Session) (err error) {
// Check if the right is valid
if err := lu.Right.isValid(); err != nil {
@ -242,7 +245,7 @@ func (lu *ListUser) Update() (err error) {
}
lu.UserID = u.ID
_, err = x.
_, err = s.
Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).
Cols("right").
Update(lu)

View File

@ -18,24 +18,25 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if the user can create a new user <-> list relation
func (lu *ListUser) CanCreate(a web.Auth) (bool, error) {
return lu.canDoListUser(a)
func (lu *ListUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(s, a)
}
// CanDelete checks if the user can delete a user <-> list relation
func (lu *ListUser) CanDelete(a web.Auth) (bool, error) {
return lu.canDoListUser(a)
func (lu *ListUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(s, a)
}
// CanUpdate checks if the user can update a user <-> list relation
func (lu *ListUser) CanUpdate(a web.Auth) (bool, error) {
return lu.canDoListUser(a)
func (lu *ListUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(s, a)
}
func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) {
func (lu *ListUser) canDoListUser(s *xorm.Session, a web.Auth) (bool, error) {
// Link shares aren't allowed to do anything
if _, is := a.(*LinkSharing); is {
return false, nil

View File

@ -80,6 +80,7 @@ func TestListUser_CanDoSomething(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := x.NewSession()
lu := &ListUser{
ID: tt.fields.ID,
@ -91,15 +92,16 @@ func TestListUser_CanDoSomething(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
if got, _ := lu.CanCreate(tt.args.a); got != tt.want["CanCreate"] {
if got, _ := lu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
t.Errorf("ListUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
}
if got, _ := lu.CanDelete(tt.args.a); got != tt.want["CanDelete"] {
if got, _ := lu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("ListUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
}
if got, _ := lu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] {
if got, _ := lu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("ListUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
}
_ = s.Close()
})
}
}

View File

@ -17,6 +17,7 @@
package models
import (
"github.com/stretchr/testify/assert"
"reflect"
"runtime"
"testing"
@ -108,6 +109,7 @@ func TestListUser_Create(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := x.NewSession()
ul := &ListUser{
ID: tt.fields.ID,
@ -120,13 +122,17 @@ func TestListUser_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := ul.Create(tt.args.a)
err := ul.Create(s, tt.args.a)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Create() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "users_list", map[string]interface{}{
"user_id": ul.UserID,
@ -212,6 +218,7 @@ func TestListUser_ReadAll(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := x.NewSession()
ul := &ListUser{
ID: tt.fields.ID,
@ -223,7 +230,7 @@ func TestListUser_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
got, _, _, err := ul.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50)
got, _, _, err := ul.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
}
@ -233,6 +240,7 @@ func TestListUser_ReadAll(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(got, tt.want); !equal {
t.Errorf("ListUser.ReadAll() = %v, want %v, diff: %v", got, tt.want, diff)
}
_ = s.Close()
})
}
}
@ -292,6 +300,7 @@ func TestListUser_Update(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := x.NewSession()
lu := &ListUser{
ID: tt.fields.ID,
@ -303,13 +312,17 @@ func TestListUser_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := lu.Update()
err := lu.Update(s)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "users_list", map[string]interface{}{
"list_id": tt.fields.ListID,
@ -369,6 +382,7 @@ func TestListUser_Delete(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := x.NewSession()
lu := &ListUser{
ID: tt.fields.ID,
@ -380,13 +394,17 @@ func TestListUser_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := lu.Delete()
err := lu.Delete(s)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Delete() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertMissing(t, "users_list", map[string]interface{}{
"user_id": tt.fields.UserID,