diff --git a/pkg/db/db.go b/pkg/db/db.go index ea3ca4c11..bcfc3c91e 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -21,6 +21,7 @@ import ( "code.vikunja.io/api/pkg/log" "encoding/gob" "fmt" + "net/url" "strconv" "strings" "time" @@ -120,31 +121,43 @@ func initMysqlEngine() (engine *xorm.Engine, err error) { return } +// parsePostgreSQLHostPort parses given input in various forms defined in +// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING +// and returns proper host and port number. +func parsePostgreSQLHostPort(info string) (string, string) { + host, port := "127.0.0.1", "5432" + if strings.Contains(info, ":") && !strings.HasSuffix(info, "]") { + idx := strings.LastIndex(info, ":") + host = info[:idx] + port = info[idx+1:] + } else if len(info) > 0 { + host = info + } + return host, port +} + func initPostgresEngine() (engine *xorm.Engine, err error) { - var connStr strings.Builder - - // https://pkg.go.dev/github.com/lib/pq?tab=doc#hdr-Connection_String_Parameters - params := map[string]string{ - "user": config.DatabaseUser.GetString(), - "password": config.DatabasePassword.GetString(), - "host": config.DatabaseHost.GetString(), - "dbname": config.DatabaseDatabase.GetString(), - "sslmode": config.DatabaseSslMode.GetString(), + var connStr string + host, port := parsePostgreSQLHostPort(config.DatabaseHost.GetString()) + if strings.HasPrefix(config.DatabaseHost.GetString(), "/") { // looks like a unix socket + connStr = fmt.Sprintf("postgres://%s:%s@:%s/%ssslmode=%s&host=%s", + url.PathEscape(config.DatabaseUser.GetString()), + url.PathEscape(config.DatabasePassword.GetString()), + port, + config.DatabaseDatabase.GetString(), + config.DatabaseSslMode.GetString(), + host) + } else { + connStr = fmt.Sprintf("postgres://%s:%s@%s:%s/%ssslmode=%s", + url.PathEscape(config.DatabaseUser.GetString()), + url.PathEscape(config.DatabasePassword.GetString()), + host, + port, + config.DatabaseDatabase.GetString(), + config.DatabaseSslMode.GetString()) } - for name, value := range params { - if name != "" { - value = strings.ReplaceAll(value, "\\", "\\\\") - value = strings.ReplaceAll(value, "'", "\\'") - - if connStr.Len() > 0 { - connStr.WriteString(" ") - } - connStr.WriteString(fmt.Sprintf("%s='%s'", name, value)) - } - } - - engine, err = xorm.NewEngine("postgres", connStr.String()) + engine, err = xorm.NewEngine("postgres", connStr) if err != nil { return }