Add session handling for list users
This commit is contained in:
parent
ada332e87c
commit
a62d222132
|
@ -18,6 +18,7 @@ package models
|
|||
|
||||
import (
|
||||
"time"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"code.vikunja.io/api/pkg/user"
|
||||
"code.vikunja.io/web"
|
||||
|
@ -71,7 +72,7 @@ type UserWithRight struct {
|
|||
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
|
||||
// @Failure 500 {object} models.Message "Internal error"
|
||||
// @Router /lists/{id}/users [put]
|
||||
func (lu *ListUser) Create(a web.Auth) (err error) {
|
||||
func (lu *ListUser) Create(s *xorm.Session, a web.Auth) (err error) {
|
||||
|
||||
// Check if the right is valid
|
||||
if err := lu.Right.isValid(); err != nil {
|
||||
|
@ -85,11 +86,11 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
|
|||
}
|
||||
|
||||
// Check if the user exists
|
||||
user, err := user.GetUserByUsername(lu.Username)
|
||||
u, err := user.GetUserByUsername(lu.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lu.UserID = user.ID
|
||||
lu.UserID = u.ID
|
||||
|
||||
// Check if the user already has access or is owner of that list
|
||||
// We explicitly DONT check for teams here
|
||||
|
@ -97,7 +98,7 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
|
|||
return ErrUserAlreadyHasAccess{UserID: lu.UserID, ListID: lu.ListID}
|
||||
}
|
||||
|
||||
exist, err := x.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{})
|
||||
exist, err := s.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -106,7 +107,7 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
|
|||
}
|
||||
|
||||
// Insert user <-> list relation
|
||||
_, err = x.Insert(lu)
|
||||
_, err = s.Insert(lu)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -128,17 +129,18 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
|
|||
// @Failure 404 {object} web.HTTPError "user or list does not exist."
|
||||
// @Failure 500 {object} models.Message "Internal error"
|
||||
// @Router /lists/{listID}/users/{userID} [delete]
|
||||
func (lu *ListUser) Delete() (err error) {
|
||||
func (lu *ListUser) Delete(s *xorm.Session) (err error) {
|
||||
|
||||
// Check if the user exists
|
||||
user, err := user.GetUserByUsername(lu.Username)
|
||||
u, err := user.GetUserByUsername(lu.Username)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
lu.UserID = user.ID
|
||||
lu.UserID = u.ID
|
||||
|
||||
// Check if the user has access to the list
|
||||
has, err := x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
|
||||
has, err := s.
|
||||
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
|
||||
Get(&ListUser{})
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -147,7 +149,8 @@ func (lu *ListUser) Delete() (err error) {
|
|||
return ErrUserDoesNotHaveAccessToList{ListID: lu.ListID, UserID: lu.UserID}
|
||||
}
|
||||
|
||||
_, err = x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
|
||||
_, err = s.
|
||||
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
|
||||
Delete(&ListUser{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -172,7 +175,7 @@ func (lu *ListUser) Delete() (err error) {
|
|||
// @Failure 403 {object} web.HTTPError "No right to see the list."
|
||||
// @Failure 500 {object} models.Message "Internal error"
|
||||
// @Router /lists/{id}/users [get]
|
||||
func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
|
||||
func (lu *ListUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
|
||||
// Check if the user has access to the list
|
||||
l := &List{ID: lu.ListID}
|
||||
canRead, _, err := l.CanRead(s, a)
|
||||
|
@ -187,7 +190,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
|
|||
|
||||
// Get all users
|
||||
all := []*UserWithRight{}
|
||||
query := x.
|
||||
query := s.
|
||||
Join("INNER", "users_list", "user_id = users.id").
|
||||
Where("users_list.list_id = ?", lu.ListID).
|
||||
Where("users.username LIKE ?", "%"+search+"%")
|
||||
|
@ -204,7 +207,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
|
|||
u.Email = ""
|
||||
}
|
||||
|
||||
numberOfTotalItems, err = x.
|
||||
numberOfTotalItems, err = s.
|
||||
Join("INNER", "users_list", "user_id = users.id").
|
||||
Where("users_list.list_id = ?", lu.ListID).
|
||||
Where("users.username LIKE ?", "%"+search+"%").
|
||||
|
@ -228,7 +231,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
|
|||
// @Failure 404 {object} web.HTTPError "User or list does not exist."
|
||||
// @Failure 500 {object} models.Message "Internal error"
|
||||
// @Router /lists/{listID}/users/{userID} [post]
|
||||
func (lu *ListUser) Update() (err error) {
|
||||
func (lu *ListUser) Update(s *xorm.Session) (err error) {
|
||||
|
||||
// Check if the right is valid
|
||||
if err := lu.Right.isValid(); err != nil {
|
||||
|
@ -242,7 +245,7 @@ func (lu *ListUser) Update() (err error) {
|
|||
}
|
||||
lu.UserID = u.ID
|
||||
|
||||
_, err = x.
|
||||
_, err = s.
|
||||
Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).
|
||||
Cols("right").
|
||||
Update(lu)
|
||||
|
|
|
@ -18,24 +18,25 @@ package models
|
|||
|
||||
import (
|
||||
"code.vikunja.io/web"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
// CanCreate checks if the user can create a new user <-> list relation
|
||||
func (lu *ListUser) CanCreate(a web.Auth) (bool, error) {
|
||||
return lu.canDoListUser(a)
|
||||
func (lu *ListUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
|
||||
return lu.canDoListUser(s, a)
|
||||
}
|
||||
|
||||
// CanDelete checks if the user can delete a user <-> list relation
|
||||
func (lu *ListUser) CanDelete(a web.Auth) (bool, error) {
|
||||
return lu.canDoListUser(a)
|
||||
func (lu *ListUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
|
||||
return lu.canDoListUser(s, a)
|
||||
}
|
||||
|
||||
// CanUpdate checks if the user can update a user <-> list relation
|
||||
func (lu *ListUser) CanUpdate(a web.Auth) (bool, error) {
|
||||
return lu.canDoListUser(a)
|
||||
func (lu *ListUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
|
||||
return lu.canDoListUser(s, a)
|
||||
}
|
||||
|
||||
func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) {
|
||||
func (lu *ListUser) canDoListUser(s *xorm.Session, a web.Auth) (bool, error) {
|
||||
// Link shares aren't allowed to do anything
|
||||
if _, is := a.(*LinkSharing); is {
|
||||
return false, nil
|
||||
|
|
|
@ -80,6 +80,7 @@ func TestListUser_CanDoSomething(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := x.NewSession()
|
||||
|
||||
lu := &ListUser{
|
||||
ID: tt.fields.ID,
|
||||
|
@ -91,15 +92,16 @@ func TestListUser_CanDoSomething(t *testing.T) {
|
|||
CRUDable: tt.fields.CRUDable,
|
||||
Rights: tt.fields.Rights,
|
||||
}
|
||||
if got, _ := lu.CanCreate(tt.args.a); got != tt.want["CanCreate"] {
|
||||
if got, _ := lu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
|
||||
t.Errorf("ListUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
|
||||
}
|
||||
if got, _ := lu.CanDelete(tt.args.a); got != tt.want["CanDelete"] {
|
||||
if got, _ := lu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
|
||||
t.Errorf("ListUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
|
||||
}
|
||||
if got, _ := lu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] {
|
||||
if got, _ := lu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
|
||||
t.Errorf("ListUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
|
||||
}
|
||||
_ = s.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
@ -108,6 +109,7 @@ func TestListUser_Create(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := x.NewSession()
|
||||
|
||||
ul := &ListUser{
|
||||
ID: tt.fields.ID,
|
||||
|
@ -120,13 +122,17 @@ func TestListUser_Create(t *testing.T) {
|
|||
CRUDable: tt.fields.CRUDable,
|
||||
Rights: tt.fields.Rights,
|
||||
}
|
||||
err := ul.Create(tt.args.a)
|
||||
err := ul.Create(s, tt.args.a)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ListUser.Create() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if (err != nil) && tt.wantErr && !tt.errType(err) {
|
||||
t.Errorf("ListUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
|
||||
}
|
||||
|
||||
err = s.Commit()
|
||||
assert.NoError(t, err)
|
||||
|
||||
if !tt.wantErr {
|
||||
db.AssertExists(t, "users_list", map[string]interface{}{
|
||||
"user_id": ul.UserID,
|
||||
|
@ -212,6 +218,7 @@ func TestListUser_ReadAll(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := x.NewSession()
|
||||
|
||||
ul := &ListUser{
|
||||
ID: tt.fields.ID,
|
||||
|
@ -223,7 +230,7 @@ func TestListUser_ReadAll(t *testing.T) {
|
|||
CRUDable: tt.fields.CRUDable,
|
||||
Rights: tt.fields.Rights,
|
||||
}
|
||||
got, _, _, err := ul.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50)
|
||||
got, _, _, err := ul.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ListUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
@ -233,6 +240,7 @@ func TestListUser_ReadAll(t *testing.T) {
|
|||
if diff, equal := messagediff.PrettyDiff(got, tt.want); !equal {
|
||||
t.Errorf("ListUser.ReadAll() = %v, want %v, diff: %v", got, tt.want, diff)
|
||||
}
|
||||
_ = s.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -292,6 +300,7 @@ func TestListUser_Update(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := x.NewSession()
|
||||
|
||||
lu := &ListUser{
|
||||
ID: tt.fields.ID,
|
||||
|
@ -303,13 +312,17 @@ func TestListUser_Update(t *testing.T) {
|
|||
CRUDable: tt.fields.CRUDable,
|
||||
Rights: tt.fields.Rights,
|
||||
}
|
||||
err := lu.Update()
|
||||
err := lu.Update(s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ListUser.Update() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if (err != nil) && tt.wantErr && !tt.errType(err) {
|
||||
t.Errorf("ListUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
|
||||
}
|
||||
|
||||
err = s.Commit()
|
||||
assert.NoError(t, err)
|
||||
|
||||
if !tt.wantErr {
|
||||
db.AssertExists(t, "users_list", map[string]interface{}{
|
||||
"list_id": tt.fields.ListID,
|
||||
|
@ -369,6 +382,7 @@ func TestListUser_Delete(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := x.NewSession()
|
||||
|
||||
lu := &ListUser{
|
||||
ID: tt.fields.ID,
|
||||
|
@ -380,13 +394,17 @@ func TestListUser_Delete(t *testing.T) {
|
|||
CRUDable: tt.fields.CRUDable,
|
||||
Rights: tt.fields.Rights,
|
||||
}
|
||||
err := lu.Delete()
|
||||
err := lu.Delete(s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ListUser.Delete() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if (err != nil) && tt.wantErr && !tt.errType(err) {
|
||||
t.Errorf("ListUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
|
||||
}
|
||||
|
||||
err = s.Commit()
|
||||
assert.NoError(t, err)
|
||||
|
||||
if !tt.wantErr {
|
||||
db.AssertMissing(t, "users_list", map[string]interface{}{
|
||||
"user_id": tt.fields.UserID,
|
||||
|
|
Loading…
Reference in New Issue