Use a session for everything in users

This commit is contained in:
kolaente 2020-12-23 01:32:33 +01:00
parent e3e7021c8c
commit 1cab1f63af
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
31 changed files with 122 additions and 107 deletions

View File

@ -97,7 +97,7 @@ func getUserFromArg(arg string) *user.User {
log.Fatalf("Invalid user id: %s", err)
}
u, err := user.GetUserByID(id)
u, err := user.GetUserByID(s, id)
if err != nil {
log.Fatalf("Could not get user: %s", err)
}
@ -116,7 +116,7 @@ var userListCmd = &cobra.Command{
initialize.FullInit()
},
Run: func(cmd *cobra.Command, args []string) {
users, err := user.ListUsers("")
users, err := user.ListUsers(s, "")
if err != nil {
log.Fatalf("Error getting users: %s", err)
}
@ -158,7 +158,7 @@ var userCreateCmd = &cobra.Command{
Email: userFlagEmail,
Password: getPasswordFromFlagOrInput(),
}
newUser, err := user.CreateUser(u)
newUser, err := user.CreateUser(s, u)
if err != nil {
log.Fatalf("Error creating new user: %s", err)
}
@ -192,7 +192,7 @@ var userUpdateCmd = &cobra.Command{
u.AvatarProvider = userFlagAvatar
}
_, err := user.UpdateUser(u)
_, err := user.UpdateUser(s, u)
if err != nil {
log.Fatalf("Error updating the user: %s", err)
}
@ -213,13 +213,13 @@ var userResetPasswordCmd = &cobra.Command{
// By default we reset as usual, only with specific flag directly.
if userFlagResetPasswordDirectly {
err := user.UpdateUserPassword(u, getPasswordFromFlagOrInput())
err := user.UpdateUserPassword(s, u, getPasswordFromFlagOrInput())
if err != nil {
log.Fatalf("Could not update user password: %s", err)
}
fmt.Println("Password updated successfully.")
} else {
err := user.RequestUserPasswordResetToken(u)
err := user.RequestUserPasswordResetToken(s, u)
if err != nil {
log.Fatalf("Could not send password reset email: %s", err)
}
@ -245,7 +245,7 @@ var userChangeEnabledCmd = &cobra.Command{
} else {
u.IsActive = !u.IsActive
}
_, err := user.UpdateUser(u)
_, err := user.UpdateUser(s, u)
if err != nil {
log.Fatalf("Could not enable the user")
}

View File

@ -185,7 +185,7 @@ func (l *Label) ReadOne(s *xorm.Session) (err error) {
}
*l = *label
u, err := user.GetUserByIDSess(s, l.CreatedByID)
u, err := user.GetUserByID(s, l.CreatedByID)
if err != nil {
return err
}

View File

@ -207,7 +207,7 @@ func (l *List) ReadOne(s *xorm.Session) (err error) {
}
// Get list owner
l.Owner, err = user.GetUserByIDSess(s, l.OwnerID)
l.Owner, err = user.GetUserByID(s, l.OwnerID)
if err != nil {
return err
}
@ -295,7 +295,7 @@ type listOptions struct {
// Gets the lists only, without any tasks or so
func getRawListsForUser(s *xorm.Session, opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) {
fullUser, err := user.GetUserByIDSess(s, opts.user.ID)
fullUser, err := user.GetUserByID(s, opts.user.ID)
if err != nil {
return nil, 0, 0, err
}

View File

@ -86,7 +86,7 @@ func (lu *ListUser) Create(s *xorm.Session, a web.Auth) (err error) {
}
// Check if the user exists
u, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(s, lu.Username)
if err != nil {
return err
}
@ -132,7 +132,7 @@ func (lu *ListUser) Create(s *xorm.Session, a web.Auth) (err error) {
func (lu *ListUser) Delete(s *xorm.Session) (err error) {
// Check if the user exists
u, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(s, lu.Username)
if err != nil {
return
}
@ -239,7 +239,7 @@ func (lu *ListUser) Update(s *xorm.Session) (err error) {
}
// Check if the user exists
u, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(s, lu.Username)
if err != nil {
return err
}

View File

@ -128,7 +128,7 @@ func GetNamespaceByID(s *xorm.Session, id int64) (namespace *Namespace, err erro
}
// Get the namespace Owner
namespace.Owner, err = user.GetUserByID(namespace.OwnerID)
namespace.Owner, err = user.GetUserByID(s, namespace.OwnerID)
return
}
@ -460,7 +460,7 @@ func (n *Namespace) Create(s *xorm.Session, a web.Auth) (err error) {
n.ID = 0 // This would otherwise prevent the creation of new lists after one was created
// Check if the User exists
n.Owner, err = user.GetUserByID(a.GetID())
n.Owner, err = user.GetUserByID(s, a.GetID())
if err != nil {
return
}
@ -576,7 +576,7 @@ func (n *Namespace) Update(s *xorm.Session) (err error) {
if n.Owner != nil {
n.OwnerID = n.Owner.ID
if currentNamespace.OwnerID != n.OwnerID {
n.Owner, err = user.GetUserByID(n.OwnerID)
n.Owner, err = user.GetUserByID(s, n.OwnerID)
if err != nil {
return
}

View File

@ -81,7 +81,7 @@ func (nu *NamespaceUser) Create(s *xorm.Session, a web.Auth) (err error) {
}
// Check if the user exists
user, err := user2.GetUserByUsername(nu.Username)
user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil {
return err
}
@ -125,7 +125,7 @@ func (nu *NamespaceUser) Create(s *xorm.Session, a web.Auth) (err error) {
func (nu *NamespaceUser) Delete(s *xorm.Session) (err error) {
// Check if the user exists
user, err := user2.GetUserByUsername(nu.Username)
user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil {
return
}
@ -228,7 +228,7 @@ func (nu *NamespaceUser) Update(s *xorm.Session) (err error) {
}
// Check if the user exists
user, err := user2.GetUserByUsername(nu.Username)
user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil {
return err
}

View File

@ -135,7 +135,7 @@ func getSavedFilterSimpleByID(s *xorm.Session, id int64) (sf *SavedFilter, err e
// @Router /filters/{id} [get]
func (sf *SavedFilter) ReadOne(s *xorm.Session) error {
// s already contains almost the full saved filter from the rights check, we only need to add the user
u, err := user.GetUserByID(sf.OwnerID)
u, err := user.GetUserByID(s, sf.OwnerID)
sf.Owner = u
return err
}

View File

@ -203,7 +203,7 @@ func (la *TaskAssginee) Create(s *xorm.Session, a web.Auth) (err error) {
func (t *Task) addNewAssigneeByID(s *xorm.Session, newAssigneeID int64, list *List) (err error) {
// Check if the user exists and has access to the list
newAssignee, err := user.GetUserByID(newAssigneeID)
newAssignee, err := user.GetUserByID(s, newAssigneeID)
if err != nil {
return err
}

View File

@ -70,7 +70,7 @@ func (tc *TaskComment) Create(s *xorm.Session, a web.Auth) (err error) {
if err != nil {
return
}
tc.Author, err = user.GetUserByID(a.GetID())
tc.Author, err = user.GetUserByID(s, a.GetID())
return
}

View File

@ -645,7 +645,7 @@ func addMoreInfoToTasks(s *xorm.Session, taskMap map[int64]*Task) (err error) {
return
}
users, err := user.GetUsersByIDs(userIDs)
users, err := user.GetUsersByIDs(s, userIDs)
if err != nil {
return
}
@ -761,7 +761,7 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err
// A negative user id indicates user share links
t.CreatedByID = a.GetID() * -1
} else {
u, err := user.GetUserByID(a.GetID())
u, err := user.GetUserByID(s, a.GetID())
if err != nil {
return err
}

View File

@ -45,7 +45,7 @@ func (tm *TeamMember) Create(s *xorm.Session, a web.Auth) (err error) {
}
// Check if the user exists
user, err := user2.GetUserByUsername(tm.Username)
user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil {
return
}
@ -89,7 +89,7 @@ func (tm *TeamMember) Delete(s *xorm.Session) (err error) {
}
// Find the numeric user id
user, err := user2.GetUserByUsername(tm.Username)
user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil {
return
}
@ -112,7 +112,7 @@ func (tm *TeamMember) Delete(s *xorm.Session) (err error) {
// @Router /teams/{id}/members/{userID}/admin [post]
func (tm *TeamMember) Update(s *xorm.Session) (err error) {
// Find the numeric user id
user, err := user2.GetUserByUsername(tm.Username)
user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil {
return
}

View File

@ -142,7 +142,7 @@ func HandleCallback(c echo.Context) error {
func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err error) {
// Check if the user exists for that issuer and subject
u, err = user.GetUserWithEmail(&user.User{
u, err = user.GetUserWithEmail(s, &user.User{
Issuer: issuer,
Subject: subject,
})
@ -165,7 +165,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
uu.Username = petname.Generate(3, "-")
}
u, err = user.CreateUser(uu)
u, err = user.CreateUser(s, uu)
if err != nil && !user.IsErrUsernameExists(err) {
return nil, err
}
@ -173,7 +173,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
// If their preferred username is already taken, create some random one from the email and subject
if user.IsErrUsernameExists(err) {
uu.Username = petname.Generate(3, "-")
u, err = user.CreateUser(uu)
u, err = user.CreateUser(s, uu)
if err != nil {
return nil, err
}
@ -196,7 +196,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
if cl.Name != u.Name {
u.Name = cl.Name
}
u, err = user.UpdateUser(&user.User{
u, err = user.UpdateUser(s, &user.User{
ID: u.ID,
Email: u.Email,
Name: u.Name,

View File

@ -57,7 +57,7 @@ func GetAvatar(c echo.Context) error {
username := c.Param("username")
// Get the user
u, err := user.GetUserWithEmail(&user.User{Username: username})
u, err := user.GetUserWithEmail(s, &user.User{Username: username})
if err != nil {
log.Errorf("Error getting user for avatar: %v", err)
return handler.HandleHTTPError(err, c)
@ -117,7 +117,7 @@ func UploadAvatar(c echo.Context) (err error) {
if err != nil {
return handler.HandleHTTPError(err, c)
}
u, err := user.GetUserByID(uc.ID)
u, err := user.GetUserByID(s, uc.ID)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -180,7 +180,7 @@ func UploadAvatar(c echo.Context) (err error) {
u.AvatarFileID = f.ID
u.AvatarProvider = "upload"
if _, err := user.UpdateUser(u); err != nil {
if _, err := user.UpdateUser(s, u); err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -46,18 +46,18 @@ func Login(c echo.Context) error {
}
// Check user
user, err := user2.CheckUserCredentials(&u)
user, err := user2.CheckUserCredentials(s, &u)
if err != nil {
return handler.HandleHTTPError(err, c)
}
totpEnabled, err := user2.TOTPEnabledForUser(user)
totpEnabled, err := user2.TOTPEnabledForUser(s, user)
if err != nil {
return handler.HandleHTTPError(err, c)
}
if totpEnabled {
_, err = user2.ValidateTOTPPasscode(&user2.TOTPPasscode{
_, err = user2.ValidateTOTPPasscode(s, &user2.TOTPPasscode{
User: user,
Passcode: u.TOTPPasscode,
})
@ -104,7 +104,7 @@ func RenewToken(c echo.Context) (err error) {
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -43,7 +43,7 @@ func UserConfirmEmail(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "No token provided.")
}
err := user.ConfirmEmail(&emailConfirm)
err := user.ConfirmEmail(s, &emailConfirm)
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -43,7 +43,7 @@ func UserResetPassword(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "No password provided.")
}
err := user.ResetPassword(&pwReset)
err := user.ResetPassword(s, &pwReset)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -73,7 +73,7 @@ func UserRequestResetPasswordToken(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, err)
}
err := user.RequestUserPasswordResetTokenByEmail(&pwTokenReset)
err := user.RequestUserPasswordResetTokenByEmail(s, &pwTokenReset)
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -51,7 +51,7 @@ func RegisterUser(c echo.Context) error {
}
// Insert the user
newUser, err := user.CreateUser(datUser.APIFormat())
newUser, err := user.CreateUser(s, datUser.APIFormat())
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -57,7 +57,7 @@ func GetUserAvatarProvider(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -91,14 +91,14 @@ func ChangeUserAvatarProvider(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
return handler.HandleHTTPError(err, c)
}
user.AvatarProvider = uap.AvatarProvider
_, err = user2.UpdateUser(user)
_, err = user2.UpdateUser(s, user)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -129,7 +129,7 @@ func UpdateGeneralUserSettings(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -137,7 +137,7 @@ func UpdateGeneralUserSettings(c echo.Context) error {
user.Name = us.Name
user.EmailRemindersEnabled = us.EmailRemindersEnabled
_, err = user2.UpdateUser(user)
_, err = user2.UpdateUser(s, user)
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -41,7 +41,7 @@ func UserShow(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, "Error getting current user.")
}
user, err := user2.GetUserByID(userInfos.ID)
user, err := user2.GetUserByID(s, userInfos.ID)
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -47,7 +47,7 @@ func UserTOTPEnroll(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
t, err := user.EnrollTOTP(u)
t, err := user.EnrollTOTP(s, u)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -86,7 +86,7 @@ func UserTOTPEnable(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid model provided.")
}
err = user.EnableTOTP(passcode)
err = user.EnableTOTP(s, passcode)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -122,7 +122,7 @@ func UserTOTPDisable(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
u, err = user.GetUserByID(u.ID)
u, err = user.GetUserByID(s, u.ID)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -132,7 +132,7 @@ func UserTOTPDisable(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
err = user.DisableTOTP(u)
err = user.DisableTOTP(s, u)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -156,7 +156,7 @@ func UserTOTPQrCode(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
qrcode, err := user.GetTOTPQrCodeForUser(u)
qrcode, err := user.GetTOTPQrCodeForUser(s, u)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -186,7 +186,7 @@ func UserTOTP(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
t, err := user.GetTOTPForUser(u)
t, err := user.GetTOTPForUser(s, u)
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -56,7 +56,7 @@ func UpdateUserEmail(c echo.Context) (err error) {
return handler.HandleHTTPError(err, c)
}
emailUpdate.User, err = user.CheckUserCredentials(&user.Login{
emailUpdate.User, err = user.CheckUserCredentials(s, &user.Login{
Username: emailUpdate.User.Username,
Password: emailUpdate.Password,
})
@ -64,7 +64,7 @@ func UpdateUserEmail(c echo.Context) (err error) {
return handler.HandleHTTPError(err, c)
}
err = user.UpdateEmail(emailUpdate)
err = user.UpdateEmail(s, emailUpdate)
if err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -62,12 +62,12 @@ func UserChangePassword(c echo.Context) error {
}
// Check the current password
if _, err = user.CheckUserCredentials(&user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil {
if _, err = user.CheckUserCredentials(s, &user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil {
return handler.HandleHTTPError(err, c)
}
// Update the password
if err = user.UpdateUserPassword(doer, newPW.NewPassword); err != nil {
if err = user.UpdateUserPassword(s, doer, newPW.NewPassword); err != nil {
return handler.HandleHTTPError(err, c)
}

View File

@ -605,7 +605,7 @@ func caldavBasicAuth(username, password string, c echo.Context) (bool, error) {
Username: username,
Password: password,
}
u, err := user.CheckUserCredentials(creds)
u, err := user.CheckUserCredentials(s, creds)
if err != nil {
log.Errorf("Error during basic auth for caldav: %v", err)
return false, nil

View File

@ -20,20 +20,10 @@ package user
import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"xorm.io/xorm"
)
var x *xorm.Engine
// InitDB sets up the database connection to use in this module
func InitDB() (err error) {
x, err = db.CreateDBEngine()
if err != nil {
log.Criticalf("Could not connect to db: %v", err.Error())
return
}
// Cache
if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" {
db.RegisterTableStructsForCache(GetTables())

View File

@ -25,12 +25,15 @@ import (
// InitTests handles the actual bootstrapping of the test env
func InitTests() {
var err error
x, err = db.CreateTestEngine()
s := db.NewSession()
defer s.Close()
err = s.Sync2(GetTables()...)
if err != nil {
log.Fatal(err)
}
err = x.Sync2(GetTables()...)
err = s.Commit()
if err != nil {
log.Fatal(err)
}

View File

@ -18,6 +18,7 @@ package user
import (
"image"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/config"
"github.com/pquerna/otp"
@ -47,19 +48,19 @@ type TOTPPasscode struct {
}
// TOTPEnabledForUser checks if totp is enabled for a user - not if it is activated, use GetTOTPForUser to check that.
func TOTPEnabledForUser(user *User) (bool, error) {
func TOTPEnabledForUser(s *xorm.Session, user *User) (bool, error) {
if !config.ServiceEnableTotp.GetBool() {
return false, nil
}
t := &TOTP{}
_, err := x.Where("user_id = ?", user.ID).Get(t)
_, err := s.Where("user_id = ?", user.ID).Get(t)
return t.Enabled, err
}
// GetTOTPForUser returns the current state of totp settings for the user.
func GetTOTPForUser(user *User) (t *TOTP, err error) {
func GetTOTPForUser(s *xorm.Session, user *User) (t *TOTP, err error) {
t = &TOTP{}
exists, err := x.Where("user_id = ?", user.ID).Get(t)
exists, err := s.Where("user_id = ?", user.ID).Get(t)
if err != nil {
return
}
@ -71,8 +72,8 @@ func GetTOTPForUser(user *User) (t *TOTP, err error) {
}
// EnrollTOTP creates a new TOTP entry for the user - it does not enable it yet.
func EnrollTOTP(user *User) (t *TOTP, err error) {
isEnrolled, err := x.Where("user_id = ?", user.ID).Exist(&TOTP{})
func EnrollTOTP(s *xorm.Session, user *User) (t *TOTP, err error) {
isEnrolled, err := s.Where("user_id = ?", user.ID).Exist(&TOTP{})
if err != nil {
return
}
@ -94,18 +95,18 @@ func EnrollTOTP(user *User) (t *TOTP, err error) {
Enabled: false,
URL: key.URL(),
}
_, err = x.Insert(t)
_, err = s.Insert(t)
return
}
// EnableTOTP enables totp for a user. The provided passcode is used to verify the user has a working totp setup.
func EnableTOTP(passcode *TOTPPasscode) (err error) {
t, err := ValidateTOTPPasscode(passcode)
func EnableTOTP(s *xorm.Session, passcode *TOTPPasscode) (err error) {
t, err := ValidateTOTPPasscode(s, passcode)
if err != nil {
return
}
_, err = x.
_, err = s.
Where("id = ?", t.ID).
Cols("enabled").
Update(&TOTP{Enabled: true})
@ -113,14 +114,16 @@ func EnableTOTP(passcode *TOTPPasscode) (err error) {
}
// DisableTOTP removes all totp settings for a user.
func DisableTOTP(user *User) (err error) {
_, err = x.Where("user_id = ?", user.ID).Delete(&TOTP{})
func DisableTOTP(s *xorm.Session, user *User) (err error) {
_, err = s.
Where("user_id = ?", user.ID).
Delete(&TOTP{})
return
}
// ValidateTOTPPasscode validated totp codes of users.
func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) {
t, err = GetTOTPForUser(passcode.User)
func ValidateTOTPPasscode(s *xorm.Session, passcode *TOTPPasscode) (t *TOTP, err error) {
t, err = GetTOTPForUser(s, passcode.User)
if err != nil {
return
}
@ -133,8 +136,8 @@ func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) {
}
// GetTOTPQrCodeForUser returns a qrcode for a user's totp setting
func GetTOTPQrCodeForUser(user *User) (qrcode image.Image, err error) {
t, err := GetTOTPForUser(user)
func GetTOTPQrCodeForUser(s *xorm.Session, user *User) (qrcode image.Image, err error) {
t, err := GetTOTPForUser(s, user)
if err != nil {
return
}

View File

@ -20,6 +20,7 @@ import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/mail"
"code.vikunja.io/api/pkg/utils"
"xorm.io/xorm"
)
// EmailUpdate is the data structure to update a user's email address
@ -32,11 +33,11 @@ type EmailUpdate struct {
}
// UpdateEmail lets a user update their email address
func UpdateEmail(update *EmailUpdate) (err error) {
func UpdateEmail(s *xorm.Session, update *EmailUpdate) (err error) {
// Check the email is not already used
user := &User{}
has, err := x.Where("email = ?", update.NewEmail).Get(user)
has, err := s.Where("email = ?", update.NewEmail).Get(user)
if err != nil {
return
}
@ -46,7 +47,7 @@ func UpdateEmail(update *EmailUpdate) (err error) {
}
// Set the user as unconfirmed and the new email address
update.User, err = GetUserWithEmail(&User{ID: update.User.ID})
update.User, err = GetUserWithEmail(s, &User{ID: update.User.ID})
if err != nil {
return
}
@ -54,7 +55,7 @@ func UpdateEmail(update *EmailUpdate) (err error) {
update.User.IsActive = false
update.User.Email = update.NewEmail
update.User.EmailConfirmToken = utils.MakeRandomString(64)
_, err = x.
_, err = s.
Where("id = ?", update.User.ID).
Cols("email", "is_active", "email_confirm_token").
Update(update.User)

View File

@ -17,6 +17,8 @@
package user
import "xorm.io/xorm"
// EmailConfirm holds the token to confirm a mail address
type EmailConfirm struct {
// The email confirm token sent via email.
@ -24,7 +26,7 @@ type EmailConfirm struct {
}
// ConfirmEmail handles the confirmation of an email address
func ConfirmEmail(c *EmailConfirm) (err error) {
func ConfirmEmail(s *xorm.Session, c *EmailConfirm) (err error) {
// Check if we have an email confirm token
if c.Token == "" {
@ -33,7 +35,9 @@ func ConfirmEmail(c *EmailConfirm) (err error) {
// Check if the token is valid
user := User{}
has, err := x.Where("email_confirm_token = ?", c.Token).Get(&user)
has, err := s.
Where("email_confirm_token = ?", c.Token).
Get(&user)
if err != nil {
return
}
@ -44,6 +48,9 @@ func ConfirmEmail(c *EmailConfirm) (err error) {
user.IsActive = true
user.EmailConfirmToken = ""
_, err = x.Where("id = ?", user.ID).Cols("is_active", "email_confirm_token").Update(&user)
_, err = s.
Where("id = ?", user.ID).
Cols("is_active", "email_confirm_token").
Update(&user)
return
}

View File

@ -65,7 +65,10 @@ func TestUserEmailConfirm(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
if err := ConfirmEmail(tt.args.c); (err != nil) != tt.wantErr {
s := db.NewSession()
defer s.Close()
if err := ConfirmEmail(s, tt.args.c); (err != nil) != tt.wantErr {
t.Errorf("ConfirmEmail() error = %v, wantErr %v", err, tt.wantErr)
}
})

View File

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/mail"
"code.vikunja.io/api/pkg/utils"
"xorm.io/xorm"
)
// PasswordReset holds the data to reset a password
@ -32,7 +33,7 @@ type PasswordReset struct {
}
// ResetPassword resets a users password
func ResetPassword(reset *PasswordReset) (err error) {
func ResetPassword(s *xorm.Session, reset *PasswordReset) (err error) {
// Check if the password is not empty
if reset.NewPassword == "" {
@ -41,7 +42,9 @@ func ResetPassword(reset *PasswordReset) (err error) {
// Check if we have a token
var user User
exists, err := x.Where("password_reset_token = ?", reset.Token).Get(&user)
exists, err := s.
Where("password_reset_token = ?", reset.Token).
Get(&user)
if err != nil {
return
}
@ -57,7 +60,9 @@ func ResetPassword(reset *PasswordReset) (err error) {
}
// Save it
_, err = x.Where("id = ?", user.ID).Update(&user)
_, err = s.
Where("id = ?", user.ID).
Update(&user)
if err != nil {
return
}
@ -83,27 +88,29 @@ type PasswordTokenRequest struct {
}
// RequestUserPasswordResetTokenByEmail inserts a random token to reset a users password into the databsse
func RequestUserPasswordResetTokenByEmail(tr *PasswordTokenRequest) (err error) {
func RequestUserPasswordResetTokenByEmail(s *xorm.Session, tr *PasswordTokenRequest) (err error) {
if tr.Email == "" {
return ErrNoUsernamePassword{}
}
// Check if the user exists
user, err := GetUserWithEmail(&User{Email: tr.Email})
user, err := GetUserWithEmail(s, &User{Email: tr.Email})
if err != nil {
return
}
return RequestUserPasswordResetToken(user)
return RequestUserPasswordResetToken(s, user)
}
// RequestUserPasswordResetToken sends a user a password reset email.
func RequestUserPasswordResetToken(user *User) (err error) {
func RequestUserPasswordResetToken(s *xorm.Session, user *User) (err error) {
// Generate a token and save it
user.PasswordResetToken = utils.MakeRandomString(400)
// Save it
_, err = x.Where("id = ?", user.ID).Update(user)
_, err = s.
Where("id = ?", user.ID).
Update(user)
if err != nil {
return
}

View File

@ -20,12 +20,13 @@ package user
import (
"strconv"
"strings"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/log"
)
// ListUsers returns a list with all users, filtered by an optional searchstring
func ListUsers(searchterm string) (users []*User, err error) {
func ListUsers(s *xorm.Session, searchterm string) (users []*User, err error) {
vals := strings.Split(searchterm, ",")
ids := []int64{}
@ -39,18 +40,18 @@ func ListUsers(searchterm string) (users []*User, err error) {
}
if len(ids) > 0 {
err = x.
err = s.
In("id", ids).
Find(&users)
return
}
if searchterm == "" {
err = x.Find(&users)
err = s.Find(&users)
return
}
err = x.
err = s.
Where("username LIKE ?", "%"+searchterm+"%").
Find(&users)
return