Use sprintf to create the connection string for postgresql

Signed-off-by: kolaente <k@knt.li>
This commit is contained in:
kolaente 2020-02-15 17:59:17 +01:00
parent 5afced95c8
commit 53a6297808
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
1 changed files with 35 additions and 22 deletions

View File

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/log"
"encoding/gob"
"fmt"
"net/url"
"strconv"
"strings"
"time"
@ -120,31 +121,43 @@ func initMysqlEngine() (engine *xorm.Engine, err error) {
return
}
// parsePostgreSQLHostPort parses given input in various forms defined in
// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
// and returns proper host and port number.
func parsePostgreSQLHostPort(info string) (string, string) {
host, port := "127.0.0.1", "5432"
if strings.Contains(info, ":") && !strings.HasSuffix(info, "]") {
idx := strings.LastIndex(info, ":")
host = info[:idx]
port = info[idx+1:]
} else if len(info) > 0 {
host = info
}
return host, port
}
func initPostgresEngine() (engine *xorm.Engine, err error) {
var connStr strings.Builder
// https://pkg.go.dev/github.com/lib/pq?tab=doc#hdr-Connection_String_Parameters
params := map[string]string{
"user": config.DatabaseUser.GetString(),
"password": config.DatabasePassword.GetString(),
"host": config.DatabaseHost.GetString(),
"dbname": config.DatabaseDatabase.GetString(),
"sslmode": config.DatabaseSslMode.GetString(),
var connStr string
host, port := parsePostgreSQLHostPort(config.DatabaseHost.GetString())
if strings.HasPrefix(config.DatabaseHost.GetString(), "/") { // looks like a unix socket
connStr = fmt.Sprintf("postgres://%s:%s@:%s/%ssslmode=%s&host=%s",
url.PathEscape(config.DatabaseUser.GetString()),
url.PathEscape(config.DatabasePassword.GetString()),
port,
config.DatabaseDatabase.GetString(),
config.DatabaseSslMode.GetString(),
host)
} else {
connStr = fmt.Sprintf("postgres://%s:%s@%s:%s/%ssslmode=%s",
url.PathEscape(config.DatabaseUser.GetString()),
url.PathEscape(config.DatabasePassword.GetString()),
host,
port,
config.DatabaseDatabase.GetString(),
config.DatabaseSslMode.GetString())
}
for name, value := range params {
if name != "" {
value = strings.ReplaceAll(value, "\\", "\\\\")
value = strings.ReplaceAll(value, "'", "\\'")
if connStr.Len() > 0 {
connStr.WriteString(" ")
}
connStr.WriteString(fmt.Sprintf("%s='%s'", name, value))
}
}
engine, err = xorm.NewEngine("postgres", connStr.String())
engine, err = xorm.NewEngine("postgres", connStr)
if err != nil {
return
}