Use a session for everything in users

This commit is contained in:
kolaente 2020-12-23 01:32:41 +01:00
parent 1cab1f63af
commit 7088b64dc8
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
5 changed files with 185 additions and 81 deletions

View File

@ -201,7 +201,7 @@ func TestListUsersFromList(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := x.NewSession()
s := db.NewSession()
defer s.Close()
gotUsers, err := ListUsersFromList(s, tt.args.l, tt.args.search)

View File

@ -41,7 +41,7 @@ import (
// @Router /users [get]
func UserList(c echo.Context) error {
s := c.QueryParam("s")
users, err := user.ListUsers(s)
users, err := user.ListUsers(s, s)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -80,7 +80,7 @@ func ListUsersForList(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
canRead, _, err := list.CanRead(auth)
canRead, _, err := list.CanRead(nil, auth)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -89,7 +89,7 @@ func ListUsersForList(c echo.Context) error {
}
s := c.QueryParam("s")
users, err := models.ListUsersFromList(&list, s)
users, err := models.ListUsersFromList(s, &list, s)
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -22,6 +22,7 @@ import (
"fmt"
"reflect"
"time"
"xorm.io/xorm"
"code.vikunja.io/web"
"github.com/dgrijalva/jwt-go"
@ -116,38 +117,33 @@ func (apiUser *APIUserPassword) APIFormat() *User {
}
// GetUserByID gets informations about a user by its ID
func GetUserByID(id int64) (user *User, err error) {
func GetUserByID(s *xorm.Session, id int64) (user *User, err error) {
// Apparently xorm does otherwise look for all users but return only one, which leads to returing one even if the ID is 0
if id < 1 {
return &User{}, ErrUserDoesNotExist{}
}
return GetUser(&User{ID: id})
return getUser(s, &User{ID: id}, false)
}
// GetUserByUsername gets a user from its user name. This is an extra function to be able to add an extra error check.
func GetUserByUsername(username string) (user *User, err error) {
func GetUserByUsername(s *xorm.Session, username string) (user *User, err error) {
if username == "" {
return &User{}, ErrUserDoesNotExist{}
}
return GetUser(&User{Username: username})
}
// GetUser gets a user object
func GetUser(user *User) (userOut *User, err error) {
return getUser(user, false)
return getUser(s, &User{Username: username}, false)
}
// GetUserWithEmail returns a user object with email
func GetUserWithEmail(user *User) (userOut *User, err error) {
return getUser(user, true)
func GetUserWithEmail(s *xorm.Session, user *User) (userOut *User, err error) {
return getUser(s, user, true)
}
// GetUsersByIDs returns a map of users from a slice of user ids
func GetUsersByIDs(userIDs []int64) (users map[int64]*User, err error) {
func GetUsersByIDs(s *xorm.Session, userIDs []int64) (users map[int64]*User, err error) {
users = make(map[int64]*User)
err = x.In("id", userIDs).Find(&users)
err = s.In("id", userIDs).Find(&users)
if err != nil {
return
}
@ -161,10 +157,10 @@ func GetUsersByIDs(userIDs []int64) (users map[int64]*User, err error) {
}
// getUser is a small helper function to avoid having duplicated code for almost the same use case
func getUser(user *User, withEmail bool) (userOut *User, err error) {
func getUser(s *xorm.Session, user *User, withEmail bool) (userOut *User, err error) {
userOut = &User{} // To prevent a panic if user is nil
*userOut = *user
exists, err := x.Get(userOut)
exists, err := s.Get(userOut)
if err != nil {
return nil, err
}
@ -179,9 +175,9 @@ func getUser(user *User, withEmail bool) (userOut *User, err error) {
return userOut, err
}
func getUserByUsernameOrEmail(usernameOrEmail string) (u *User, err error) {
func getUserByUsernameOrEmail(s *xorm.Session, usernameOrEmail string) (u *User, err error) {
u = &User{}
exists, err := x.
exists, err := s.
Where("username = ? OR email = ?", usernameOrEmail, usernameOrEmail).
Get(u)
if err != nil {
@ -196,14 +192,14 @@ func getUserByUsernameOrEmail(usernameOrEmail string) (u *User, err error) {
}
// CheckUserCredentials checks user credentials
func CheckUserCredentials(u *Login) (*User, error) {
func CheckUserCredentials(s *xorm.Session, u *Login) (*User, error) {
// Check if we have any credentials
if u.Password == "" || u.Username == "" {
return nil, ErrNoUsernamePassword{}
}
// Check if the user exists
user, err := getUserByUsernameOrEmail(u.Username)
user, err := getUserByUsernameOrEmail(s, u.Username)
if err != nil {
// hashing the password takes a long time, so we hash something to not make it clear if the username was wrong
_, _ = bcrypt.GenerateFromPassword([]byte(u.Username), 14)
@ -261,10 +257,10 @@ func GetUserFromClaims(claims jwt.MapClaims) (user *User, err error) {
}
// UpdateUser updates a user
func UpdateUser(user *User) (updatedUser *User, err error) {
func UpdateUser(s *xorm.Session, user *User) (updatedUser *User, err error) {
// Check if it exists
theUser, err := GetUserWithEmail(&User{ID: user.ID})
theUser, err := GetUserWithEmail(s, &User{ID: user.ID})
if err != nil {
return &User{}, err
}
@ -274,7 +270,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) {
user.Username = theUser.Username // Dont change the username if we dont have one
} else {
// Check if the new username already exists
uu, err := GetUserByUsername(user.Username)
uu, err := GetUserByUsername(s, user.Username)
if err != nil && !IsErrUserDoesNotExist(err) {
return nil, err
}
@ -292,7 +288,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) {
if user.Email == "" {
user.Email = theUser.Email
} else {
uu, err := getUser(&User{
uu, err := getUser(s, &User{
Email: user.Email,
Issuer: user.Issuer,
Subject: user.Subject,
@ -316,7 +312,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) {
}
// Update it
_, err = x.
_, err = s.
ID(user.ID).
Cols(
"username",
@ -333,7 +329,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) {
}
// Get the newly updated user
updatedUser, err = GetUserByID(user.ID)
updatedUser, err = GetUserByID(s, user.ID)
if err != nil {
return &User{}, err
}
@ -342,14 +338,14 @@ func UpdateUser(user *User) (updatedUser *User, err error) {
}
// UpdateUserPassword updates the password of a user
func UpdateUserPassword(user *User, newPassword string) (err error) {
func UpdateUserPassword(s *xorm.Session, user *User, newPassword string) (err error) {
if newPassword == "" {
return ErrEmptyNewPassword{}
}
// Get all user details
theUser, err := GetUserByID(user.ID)
theUser, err := GetUserByID(s, user.ID)
if err != nil {
return err
}
@ -362,7 +358,7 @@ func UpdateUserPassword(user *User, newPassword string) (err error) {
theUser.Password = hashed
// Update it
_, err = x.ID(user.ID).Update(theUser)
_, err = s.ID(user.ID).Update(theUser)
if err != nil {
return err
}

View File

@ -22,12 +22,13 @@ import (
"code.vikunja.io/api/pkg/metrics"
"code.vikunja.io/api/pkg/utils"
"golang.org/x/crypto/bcrypt"
"xorm.io/xorm"
)
const issuerLocal = `local`
// CreateUser creates a new user and inserts it into the database
func CreateUser(user *User) (newUser *User, err error) {
func CreateUser(s *xorm.Session, user *User) (newUser *User, err error) {
if user.Issuer == "" {
user.Issuer = issuerLocal
@ -40,7 +41,7 @@ func CreateUser(user *User) (newUser *User, err error) {
}
// Check if the user already exists with that username
err = checkIfUserExists(user)
err = checkIfUserExists(s, user)
if err != nil {
return nil, err
}
@ -64,7 +65,7 @@ func CreateUser(user *User) (newUser *User, err error) {
user.AvatarProvider = "initials"
// Insert it
_, err = x.Insert(user)
_, err = s.Insert(user)
if err != nil {
return nil, err
}
@ -73,7 +74,7 @@ func CreateUser(user *User) (newUser *User, err error) {
metrics.UpdateCount(1, metrics.ActiveUsersKey)
// Get the full new User
newUserOut, err := GetUserByID(user.ID)
newUserOut, err := GetUserByID(s, user.ID)
if err != nil {
return nil, err
}
@ -100,9 +101,9 @@ func checkIfUserIsValid(user *User) error {
return nil
}
func checkIfUserExists(user *User) (err error) {
func checkIfUserExists(s *xorm.Session, user *User) (err error) {
exists := true
_, err = GetUserByUsername(user.Username)
_, err = GetUserByUsername(s, user.Username)
if err != nil {
if IsErrUserDoesNotExist(err) {
exists = false
@ -126,7 +127,7 @@ func checkIfUserExists(user *User) (err error) {
userToCheck.Email = ""
}
_, err = GetUser(userToCheck)
_, err = getUser(s, userToCheck, false)
if err != nil {
if IsErrUserDoesNotExist(err) {
exists = false

View File

@ -34,13 +34,19 @@ func TestCreateUser(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
createdUser, err := CreateUser(dummyuser)
s := db.NewSession()
defer s.Close()
createdUser, err := CreateUser(s, dummyuser)
assert.NoError(t, err)
assert.NotZero(t, createdUser.Created)
})
t.Run("already existing", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CreateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := CreateUser(s, &User{
Username: "user1",
Password: "12345",
Email: "email@example.com",
@ -50,7 +56,10 @@ func TestCreateUser(t *testing.T) {
})
t.Run("same email", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CreateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := CreateUser(s, &User{
Username: "testuser",
Password: "12345",
Email: "user1@example.com",
@ -60,7 +69,10 @@ func TestCreateUser(t *testing.T) {
})
t.Run("no username", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CreateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := CreateUser(s, &User{
Username: "",
Password: "12345",
Email: "user1@example.com",
@ -70,7 +82,10 @@ func TestCreateUser(t *testing.T) {
})
t.Run("no password", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CreateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := CreateUser(s, &User{
Username: "testuser",
Password: "",
Email: "user1@example.com",
@ -80,7 +95,10 @@ func TestCreateUser(t *testing.T) {
})
t.Run("no email", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CreateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := CreateUser(s, &User{
Username: "testuser",
Password: "12345",
Email: "",
@ -90,7 +108,10 @@ func TestCreateUser(t *testing.T) {
})
t.Run("same email but different issuer", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CreateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := CreateUser(s, &User{
Username: "somenewuser",
Email: "user1@example.com",
Issuer: "https://some.site",
@ -100,7 +121,10 @@ func TestCreateUser(t *testing.T) {
})
t.Run("same subject but different issuer", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CreateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := CreateUser(s, &User{
Username: "somenewuser",
Email: "somenewuser@example.com",
Issuer: "https://some.site",
@ -113,25 +137,41 @@ func TestCreateUser(t *testing.T) {
func TestGetUser(t *testing.T) {
t.Run("by name", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
theuser, err := GetUser(&User{
Username: "user1",
})
s := db.NewSession()
defer s.Close()
theuser, err := getUser(
s,
&User{
Username: "user1",
},
false,
)
assert.NoError(t, err)
assert.Equal(t, theuser.ID, int64(1))
assert.Empty(t, theuser.Email)
})
t.Run("by email", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
theuser, err := GetUser(&User{
Email: "user1@example.com",
})
s := db.NewSession()
defer s.Close()
theuser, err := getUser(
s,
&User{
Email: "user1@example.com",
},
false)
assert.NoError(t, err)
assert.Equal(t, theuser.ID, int64(1))
assert.Empty(t, theuser.Email)
})
t.Run("by id", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
theuser, err := GetUserByID(1)
s := db.NewSession()
defer s.Close()
theuser, err := GetUserByID(s, 1)
assert.NoError(t, err)
assert.Equal(t, theuser.ID, int64(1))
assert.Equal(t, theuser.Username, "user1")
@ -139,25 +179,37 @@ func TestGetUser(t *testing.T) {
})
t.Run("invalid id", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := GetUserByID(99999)
s := db.NewSession()
defer s.Close()
_, err := GetUserByID(s, 99999)
assert.Error(t, err)
assert.True(t, IsErrUserDoesNotExist(err))
})
t.Run("nonexistant", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := GetUserByID(0)
s := db.NewSession()
defer s.Close()
_, err := GetUserByID(s, 0)
assert.Error(t, err)
assert.True(t, IsErrUserDoesNotExist(err))
})
t.Run("empty name", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := GetUserByUsername("")
s := db.NewSession()
defer s.Close()
_, err := GetUserByUsername(s, "")
assert.Error(t, err)
assert.True(t, IsErrUserDoesNotExist(err))
})
t.Run("with email", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
theuser, err := GetUserWithEmail(&User{ID: 1})
s := db.NewSession()
defer s.Close()
theuser, err := GetUserWithEmail(s, &User{ID: 1})
assert.NoError(t, err)
assert.Equal(t, theuser.ID, int64(1))
assert.Equal(t, theuser.Username, "user1")
@ -168,42 +220,63 @@ func TestGetUser(t *testing.T) {
func TestCheckUserCredentials(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CheckUserCredentials(&Login{Username: "user1", Password: "1234"})
s := db.NewSession()
defer s.Close()
_, err := CheckUserCredentials(s, &Login{Username: "user1", Password: "1234"})
assert.NoError(t, err)
})
t.Run("unverified email", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CheckUserCredentials(&Login{Username: "user5", Password: "1234"})
s := db.NewSession()
defer s.Close()
_, err := CheckUserCredentials(s, &Login{Username: "user5", Password: "1234"})
assert.Error(t, err)
assert.True(t, IsErrEmailNotConfirmed(err))
})
t.Run("wrong password", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CheckUserCredentials(&Login{Username: "user1", Password: "12345"})
s := db.NewSession()
defer s.Close()
_, err := CheckUserCredentials(s, &Login{Username: "user1", Password: "12345"})
assert.Error(t, err)
assert.True(t, IsErrWrongUsernameOrPassword(err))
})
t.Run("nonexistant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CheckUserCredentials(&Login{Username: "dfstestuu", Password: "1234"})
s := db.NewSession()
defer s.Close()
_, err := CheckUserCredentials(s, &Login{Username: "dfstestuu", Password: "1234"})
assert.Error(t, err)
assert.True(t, IsErrWrongUsernameOrPassword(err))
})
t.Run("empty password", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CheckUserCredentials(&Login{Username: "user1"})
s := db.NewSession()
defer s.Close()
_, err := CheckUserCredentials(s, &Login{Username: "user1"})
assert.Error(t, err)
assert.True(t, IsErrNoUsernamePassword(err))
})
t.Run("empty username", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CheckUserCredentials(&Login{Password: "1234"})
s := db.NewSession()
defer s.Close()
_, err := CheckUserCredentials(s, &Login{Password: "1234"})
assert.Error(t, err)
assert.True(t, IsErrNoUsernamePassword(err))
})
t.Run("email", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := CheckUserCredentials(&Login{Username: "user1@example.com", Password: "1234"})
s := db.NewSession()
defer s.Close()
_, err := CheckUserCredentials(s, &Login{Username: "user1@example.com", Password: "1234"})
assert.NoError(t, err)
})
}
@ -211,7 +284,10 @@ func TestCheckUserCredentials(t *testing.T) {
func TestUpdateUser(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
uuser, err := UpdateUser(&User{
s := db.NewSession()
defer s.Close()
uuser, err := UpdateUser(s, &User{
ID: 1,
Password: "LoremIpsum",
Email: "testing@example.com",
@ -222,7 +298,10 @@ func TestUpdateUser(t *testing.T) {
})
t.Run("change username", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
uuser, err := UpdateUser(&User{
s := db.NewSession()
defer s.Close()
uuser, err := UpdateUser(s, &User{
ID: 1,
Username: "changedname",
})
@ -232,7 +311,10 @@ func TestUpdateUser(t *testing.T) {
})
t.Run("nonexistant", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
_, err := UpdateUser(&User{
s := db.NewSession()
defer s.Close()
_, err := UpdateUser(s, &User{
ID: 99999,
})
assert.Error(t, err)
@ -244,15 +326,20 @@ func TestUpdateUserPassword(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
err := UpdateUserPassword(&User{
s := db.NewSession()
defer s.Close()
err := UpdateUserPassword(s, &User{
ID: 1,
}, "12345",
)
}, "12345")
assert.NoError(t, err)
})
t.Run("nonexistant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
err := UpdateUserPassword(&User{
s := db.NewSession()
defer s.Close()
err := UpdateUserPassword(s, &User{
ID: 9999,
}, "12345")
assert.Error(t, err)
@ -260,10 +347,12 @@ func TestUpdateUserPassword(t *testing.T) {
})
t.Run("empty password", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
err := UpdateUserPassword(&User{
s := db.NewSession()
defer s.Close()
err := UpdateUserPassword(s, &User{
ID: 1,
}, "",
)
}, "")
assert.Error(t, err)
assert.True(t, IsErrEmptyNewPassword(err))
})
@ -272,14 +361,20 @@ func TestUpdateUserPassword(t *testing.T) {
func TestListUsers(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
all, err := ListUsers("user1")
s := db.NewSession()
defer s.Close()
all, err := ListUsers(s, "user1")
assert.NoError(t, err)
assert.True(t, len(all) > 0)
assert.Equal(t, all[0].Username, "user1")
})
t.Run("all users", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
all, err := ListUsers("")
s := db.NewSession()
defer s.Close()
all, err := ListUsers(s, "")
assert.NoError(t, err)
assert.Len(t, all, 14)
})
@ -288,39 +383,51 @@ func TestListUsers(t *testing.T) {
func TestUserPasswordReset(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
reset := &PasswordReset{
Token: "passwordresettesttoken",
NewPassword: "12345",
}
err := ResetPassword(reset)
err := ResetPassword(s, reset)
assert.NoError(t, err)
})
t.Run("without password", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
reset := &PasswordReset{
Token: "passwordresettesttoken",
}
err := ResetPassword(reset)
err := ResetPassword(s, reset)
assert.Error(t, err)
assert.True(t, IsErrNoUsernamePassword(err))
})
t.Run("empty token", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
reset := &PasswordReset{
Token: "somethingsomething",
NewPassword: "12345",
}
err := ResetPassword(reset)
err := ResetPassword(s, reset)
assert.Error(t, err)
assert.True(t, IsErrInvalidPasswordResetToken(err))
})
t.Run("wrong token", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
reset := &PasswordReset{
Token: "somethingsomething",
NewPassword: "12345",
}
err := ResetPassword(reset)
err := ResetPassword(s, reset)
assert.Error(t, err)
assert.True(t, IsErrInvalidPasswordResetToken(err))
})