Make sure the openid stuff uses a session

This commit is contained in:
kolaente 2020-12-23 01:43:44 +01:00
parent c3c7cab045
commit 996a207590
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
3 changed files with 39 additions and 11 deletions

View File

@ -17,11 +17,13 @@
package openid
import (
"code.vikunja.io/api/pkg/db"
"context"
"encoding/json"
"math/rand"
"net/http"
"time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
@ -130,8 +132,17 @@ func HandleCallback(c echo.Context) error {
return err
}
s := db.NewSession()
defer s.Close()
// Check if we have seen this user before
u, err := getOrCreateUser(cl, idToken.Issuer, idToken.Subject)
u, err := getOrCreateUser(s, cl, idToken.Issuer, idToken.Subject)
if err != nil {
_ = s.Rollback()
return err
}
err = s.Commit()
if err != nil {
return err
}
@ -140,7 +151,7 @@ func HandleCallback(c echo.Context) error {
return auth.NewUserAuthTokenResponse(u, c)
}
func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err error) {
func getOrCreateUser(s *xorm.Session, cl *claims, issuer, subject string) (u *user.User, err error) {
// Check if the user exists for that issuer and subject
u, err = user.GetUserWithEmail(s, &user.User{
Issuer: issuer,

View File

@ -26,12 +26,18 @@ import (
func TestGetOrCreateUser(t *testing.T) {
t.Run("new user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "test@example.com",
PreferredUsername: "someUserWhoDoesNotExistYet",
}
u, err := getOrCreateUser(cl, "https://some.issuer", "12345")
u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID,
"email": cl.Email,
@ -40,13 +46,19 @@ func TestGetOrCreateUser(t *testing.T) {
})
t.Run("new user, no username provided", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "test@example.com",
PreferredUsername: "",
}
u, err := getOrCreateUser(cl, "https://some.issuer", "12345")
u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.NoError(t, err)
assert.NotEmpty(t, u.Username)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID,
"email": cl.Email,
@ -54,19 +66,28 @@ func TestGetOrCreateUser(t *testing.T) {
})
t.Run("new user, no email address", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "",
}
_, err := getOrCreateUser(cl, "https://some.issuer", "12345")
_, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.Error(t, err)
})
t.Run("existing user, different email address", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "other-email-address@some.service.com",
}
u, err := getOrCreateUser(cl, "https://some.service.com", "12345")
u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345")
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID,
"email": cl.Email,

View File

@ -174,11 +174,7 @@ func NewEcho() *echo.Echo {
})
handler.SetLoggingProvider(log.GetLogger())
handler.SetMaxItemsPerPage(config.ServiceMaxItemsPerPage.GetInt())
x, err := db.CreateDBEngine()
if err != nil {
log.Criticalf("Could not get db engine for handler: %s", err)
}
handler.SetSessionFactory(x.NewSession)
handler.SetSessionFactory(db.NewSession)
return e
}