Make sure the openid stuff uses a session
This commit is contained in:
parent
c3c7cab045
commit
996a207590
|
@ -17,11 +17,13 @@
|
|||
package openid
|
||||
|
||||
import (
|
||||
"code.vikunja.io/api/pkg/db"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"code.vikunja.io/api/pkg/log"
|
||||
"code.vikunja.io/api/pkg/models"
|
||||
|
@ -130,8 +132,17 @@ func HandleCallback(c echo.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
// Check if we have seen this user before
|
||||
u, err := getOrCreateUser(cl, idToken.Issuer, idToken.Subject)
|
||||
u, err := getOrCreateUser(s, cl, idToken.Issuer, idToken.Subject)
|
||||
if err != nil {
|
||||
_ = s.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -140,7 +151,7 @@ func HandleCallback(c echo.Context) error {
|
|||
return auth.NewUserAuthTokenResponse(u, c)
|
||||
}
|
||||
|
||||
func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err error) {
|
||||
func getOrCreateUser(s *xorm.Session, cl *claims, issuer, subject string) (u *user.User, err error) {
|
||||
// Check if the user exists for that issuer and subject
|
||||
u, err = user.GetUserWithEmail(s, &user.User{
|
||||
Issuer: issuer,
|
||||
|
|
|
@ -26,12 +26,18 @@ import (
|
|||
func TestGetOrCreateUser(t *testing.T) {
|
||||
t.Run("new user", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
cl := &claims{
|
||||
Email: "test@example.com",
|
||||
PreferredUsername: "someUserWhoDoesNotExistYet",
|
||||
}
|
||||
u, err := getOrCreateUser(cl, "https://some.issuer", "12345")
|
||||
u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
|
||||
assert.NoError(t, err)
|
||||
err = s.Commit()
|
||||
assert.NoError(t, err)
|
||||
|
||||
db.AssertExists(t, "users", map[string]interface{}{
|
||||
"id": u.ID,
|
||||
"email": cl.Email,
|
||||
|
@ -40,13 +46,19 @@ func TestGetOrCreateUser(t *testing.T) {
|
|||
})
|
||||
t.Run("new user, no username provided", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
cl := &claims{
|
||||
Email: "test@example.com",
|
||||
PreferredUsername: "",
|
||||
}
|
||||
u, err := getOrCreateUser(cl, "https://some.issuer", "12345")
|
||||
u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, u.Username)
|
||||
err = s.Commit()
|
||||
assert.NoError(t, err)
|
||||
|
||||
db.AssertExists(t, "users", map[string]interface{}{
|
||||
"id": u.ID,
|
||||
"email": cl.Email,
|
||||
|
@ -54,19 +66,28 @@ func TestGetOrCreateUser(t *testing.T) {
|
|||
})
|
||||
t.Run("new user, no email address", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
cl := &claims{
|
||||
Email: "",
|
||||
}
|
||||
_, err := getOrCreateUser(cl, "https://some.issuer", "12345")
|
||||
_, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("existing user, different email address", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
cl := &claims{
|
||||
Email: "other-email-address@some.service.com",
|
||||
}
|
||||
u, err := getOrCreateUser(cl, "https://some.service.com", "12345")
|
||||
u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345")
|
||||
assert.NoError(t, err)
|
||||
err = s.Commit()
|
||||
assert.NoError(t, err)
|
||||
|
||||
db.AssertExists(t, "users", map[string]interface{}{
|
||||
"id": u.ID,
|
||||
"email": cl.Email,
|
||||
|
|
|
@ -174,11 +174,7 @@ func NewEcho() *echo.Echo {
|
|||
})
|
||||
handler.SetLoggingProvider(log.GetLogger())
|
||||
handler.SetMaxItemsPerPage(config.ServiceMaxItemsPerPage.GetInt())
|
||||
x, err := db.CreateDBEngine()
|
||||
if err != nil {
|
||||
log.Criticalf("Could not get db engine for handler: %s", err)
|
||||
}
|
||||
handler.SetSessionFactory(x.NewSession)
|
||||
handler.SetSessionFactory(db.NewSession)
|
||||
|
||||
return e
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue