Add retreiving auth tokens

This commit is contained in:
kolaente 2020-10-25 19:38:24 +01:00
parent 65264f7948
commit 05e4bfdca0
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
3 changed files with 65 additions and 37 deletions

View File

@ -238,3 +238,5 @@ auth:
authurl: authurl:
# The client ID used to authenticate Vikunja at the OpenID Connect provider. # The client ID used to authenticate Vikunja at the OpenID Connect provider.
clientid: clientid:
# The client secret used to authenticate Vikunja at the OpenID Connect provider.
clientsecret:

View File

@ -18,6 +18,7 @@ package openid
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
@ -52,38 +53,32 @@ func getKeyFromName(name string) string {
return reg.ReplaceAllString(strings.ToLower(name), "") return reg.ReplaceAllString(strings.ToLower(name), "")
} }
func GetAllProviders() (providers []*Provider) { func GetAllProviders() (providers []*Provider, err error) {
rawProvider := config.AuthOpenIDProviders.Get().([]interface{}) rawProvider := config.AuthOpenIDProviders.Get().([]interface{})
for _, p := range rawProvider { for _, p := range rawProvider {
pi := p.(map[interface{}]interface{}) pi := p.(map[interface{}]interface{})
providers = append(providers, &Provider{ provider, err := getProviderFromMap(pi)
Name: pi["name"].(string), if err != nil {
Key: getKeyFromName(pi["name"].(string)), return nil, err
AuthURL: pi["authurl"].(string), }
ClientID: pi["clientid"].(string),
}) providers = append(providers, provider)
} }
return return
} }
func GetProvider(key string) (*Provider, error) { func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
rawProvider := config.AuthOpenIDProviders.Get().([]interface{})
for _, p := range rawProvider {
pi := p.(map[interface{}]interface{})
k := getKeyFromName(pi["name"].(string)) k := getKeyFromName(pi["name"].(string))
if k == key {
provider := &Provider{ provider := &Provider{
Name: pi["name"].(string), Name: pi["name"].(string),
Key: k, Key: k,
AuthURL: pi["authurl"].(string), AuthURL: pi["authurl"].(string),
ClientID: pi["clientid"].(string), ClientID: pi["clientid"].(string),
// TODO ClientSecret: pi["clientsecret"].(string),
// ClientSecret
} }
var err error var err error
@ -104,8 +99,21 @@ func GetProvider(key string) (*Provider, error) {
Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
} }
provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL
return provider, nil return provider, nil
} }
func GetProvider(key string) (*Provider, error) {
rawProvider := config.AuthOpenIDProviders.Get().([]interface{})
for _, p := range rawProvider {
pi := p.(map[interface{}]interface{})
k := getKeyFromName(pi["name"].(string))
if k == key {
return getProviderFromMap(pi)
}
} }
return nil, nil return nil, nil
@ -131,7 +139,20 @@ func HandleCallback(c echo.Context) error {
// Parse the access & ID token // Parse the access & ID token
oauth2Token, err := provider.Oauth2Config.Exchange(context.Background(), cb.Code) oauth2Token, err := provider.Oauth2Config.Exchange(context.Background(), cb.Code)
if err != nil { if err != nil {
if rerr, is := err.(*oauth2.RetrieveError); is {
log.Error(err) log.Error(err)
details := make(map[string]interface{})
if err := json.Unmarshal(rerr.Body, &details); err != nil {
return err
}
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"message": "Could not authenticate against third party.",
"details": details,
})
}
return err return err
} }

View File

@ -91,11 +91,16 @@ func Info(c echo.Context) error {
OpenIDConnect: openIDAuthInfo{ OpenIDConnect: openIDAuthInfo{
Enabled: config.AuthOpenIDEnabled.GetBool(), Enabled: config.AuthOpenIDEnabled.GetBool(),
RedirectURL: config.AuthOpenIDRedirectURL.GetString(), RedirectURL: config.AuthOpenIDRedirectURL.GetString(),
Providers: openid.GetAllProviders(),
}, },
}, },
} }
var err error
info.AuthInfo.OpenIDConnect.Providers, err = openid.GetAllProviders()
if err != nil {
return err
}
// Migrators // Migrators
if config.MigrationWunderlistEnable.GetBool() { if config.MigrationWunderlistEnable.GetBool() {
m := &wunderlist.Migration{} m := &wunderlist.Migration{}