Use db sessions everywere #750

Merged
konrad merged 44 commits from feature/db-sessions into master 2020-12-23 15:32:29 +00:00
107 changed files with 2428 additions and 1279 deletions

8
go.mod
View File

@ -18,7 +18,7 @@ module code.vikunja.io/api
require (
4d63.com/tz v1.2.0
code.vikunja.io/web v0.0.0-20200809154828-8767618f181f
code.vikunja.io/web v0.0.0-20201223143420-588abb73703a
dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 // indirect
gitea.com/xorm/xorm-redis-cache v0.2.0
github.com/adlio/trello v1.8.0
@ -41,6 +41,7 @@ require (
github.com/go-sql-driver/mysql v1.5.0
github.com/go-testfixtures/testfixtures/v3 v3.4.1
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
github.com/golang/snappy v0.0.2 // indirect
github.com/gordonklaus/ineffassign v0.0.0-20201107091007-3b93a8888063
github.com/iancoleman/strcase v0.1.2
github.com/imdario/mergo v0.3.11
@ -52,6 +53,7 @@ require (
github.com/lib/pq v1.9.0
github.com/magefile/mage v1.10.0
github.com/mailru/easyjson v0.7.6 // indirect
github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-sqlite3 v1.14.5
github.com/mitchellh/mapstructure v1.3.2 // indirect
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
@ -76,8 +78,10 @@ require (
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad
golang.org/x/image v0.0.0-20201208152932-35266b937fa6
golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5
golang.org/x/net v0.0.0-20201216054612-986b41b23924 // indirect
golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
golang.org/x/sys v0.0.0-20201223074533-0d417f636930 // indirect
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
@ -91,7 +95,7 @@ require (
src.techknowlogick.com/xormigrate v1.4.0
xorm.io/builder v0.3.7
xorm.io/core v0.7.3
xorm.io/xorm v1.0.2
xorm.io/xorm v1.0.5
)
replace (

25
go.sum
View File

@ -38,8 +38,12 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo
cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
code.vikunja.io/web v0.0.0-20200809154828-8767618f181f h1:Zgtk9lbJkGbKjdTC78mg/c2uNkesxDJs1YUIL9zGvco=
code.vikunja.io/web v0.0.0-20200809154828-8767618f181f/go.mod h1:vDWiCtftF6LNCCrem7mjstPWMgzLUvMW/L4YwIQ1Voo=
code.vikunja.io/web v0.0.0-20201218134444-505d0e77fac7 h1:iS3TFA+y1If6DEbqzad5Ge7TI1NxZr9BevC/dU4ygEo=
code.vikunja.io/web v0.0.0-20201218134444-505d0e77fac7/go.mod h1:vDWiCtftF6LNCCrem7mjstPWMgzLUvMW/L4YwIQ1Voo=
code.vikunja.io/web v0.0.0-20201222144643-6fa2fb587215 h1:O5zMWgcnVDVLaQUawgdsv/jX/4SUUAvSedvRR+5+x2o=
code.vikunja.io/web v0.0.0-20201222144643-6fa2fb587215/go.mod h1:OgFO06HN1KpA4S7Dw/QAIeygiUPSeGJJn1ykz/sjZdU=
code.vikunja.io/web v0.0.0-20201223143420-588abb73703a h1:LaWCucY5Pp30EIMgGOvdVFNss5OhIAwrAO8PuFVRUfw=
code.vikunja.io/web v0.0.0-20201223143420-588abb73703a/go.mod h1:OgFO06HN1KpA4S7Dw/QAIeygiUPSeGJJn1ykz/sjZdU=
dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 h1:o4lAkfETerCnr1kF9/qwkwjICnU+YLHNDCM8h2xj7as=
dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363/go.mod h1:WG7q7swWsS2f9PYpt5DoEP/EBYWx8We5UoRltn9vJl8=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
@ -152,6 +156,7 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 h1:YcpmyvADGYw5LqMnHqSkyIELsHCGF6PkrmM31V8rF7o=
github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM=
github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204 h1:tI48fqaIkxxYuIylVv1tdDfBp6836GKSfmmzgSyP1CY=
github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
@ -293,6 +298,8 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pO
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw=
github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
@ -491,6 +498,7 @@ github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU=
github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.8.0 h1:9xohqzkUwzR4Ga4ivdTcawVS89YSDVxXMa3xJX3cGzg=
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8=
@ -516,6 +524,8 @@ github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+v
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw=
github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.4 h1:bnP0vzxcAdeI1zdubAl5PjU6zsERjGZb7raWodagDYs=
github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
@ -852,8 +862,6 @@ golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de h1:ikNHVSjEfnvz6sxdSPCaPt
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201217014255-9d1352758620 h1:3wPMTskHO3+O6jqTEXyFcsnuxMQOqYSaHsDxcbUXpqA=
golang.org/x/crypto v0.0.0-20201217014255-9d1352758620/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -944,6 +952,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2l
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201216054612-986b41b23924 h1:QsnDpLLOKwHBBDa8nDws4DYNc/ryVW2vCpxCs09d4PY=
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
@ -1028,8 +1038,13 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuF
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e h1:AyodaIpKjppX+cBfTASF2E1US3H2JFBj920Ot3rtDjs=
golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201221093633-bc327ba9c2f0 h1:n+DPcgTwkgWzIFpLmoimYR2K2b0Ga5+Os4kayIN0vGo=
golang.org/x/sys v0.0.0-20201221093633-bc327ba9c2f0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201223074533-0d417f636930 h1:vRgIt+nup/B/BwIS0g2oC0haq0iqbV3ZA+u6+0TlNCo=
golang.org/x/sys v0.0.0-20201223074533-0d417f636930/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M=
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -1284,3 +1299,5 @@ xorm.io/xorm v1.0.1 h1:/lITxpJtkZauNpdzj+L9CN/3OQxZaABrbergMcJu+Cw=
xorm.io/xorm v1.0.1/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY=
xorm.io/xorm v1.0.2 h1:kZlCh9rqd1AzGwWitcrEEqHE1h1eaZE/ujU5/2tWEtg=
xorm.io/xorm v1.0.2/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY=
xorm.io/xorm v1.0.5 h1:LRr5PfOUb4ODPR63YwbowkNDwcolT2LnkwP/TUaMaB0=
xorm.io/xorm v1.0.5/go.mod h1:uF9EtbhODq5kNWxMbnBEj8hRRZnlcNSz2t2N7HW/+A4=

View File

@ -24,6 +24,7 @@ import (
"strings"
"time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/initialize"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
@ -31,6 +32,7 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/term"
"xorm.io/xorm"
)
var (
@ -91,13 +93,13 @@ func getPasswordFromFlagOrInput() (pw string) {
return
}
func getUserFromArg(arg string) *user.User {
func getUserFromArg(s *xorm.Session, arg string) *user.User {
id, err := strconv.ParseInt(arg, 10, 64)
if err != nil {
log.Fatalf("Invalid user id: %s", err)
}
u, err := user.GetUserByID(id)
u, err := user.GetUserByID(s, id)
if err != nil {
log.Fatalf("Could not get user: %s", err)
}
@ -116,8 +118,16 @@ var userListCmd = &cobra.Command{
initialize.FullInit()
},
Run: func(cmd *cobra.Command, args []string) {
users, err := user.ListUsers("")
s := db.NewSession()
defer s.Close()
users, err := user.ListUsers(s, "")
if err != nil {
_ = s.Rollback()
log.Fatalf("Error getting users: %s", err)
}
if err := s.Commit(); err != nil {
log.Fatalf("Error getting users: %s", err)
}
@ -153,21 +163,30 @@ var userCreateCmd = &cobra.Command{
initialize.FullInit()
},
Run: func(cmd *cobra.Command, args []string) {
s := db.NewSession()
defer s.Close()
u := &user.User{
Username: userFlagUsername,
Email: userFlagEmail,
Password: getPasswordFromFlagOrInput(),
}
newUser, err := user.CreateUser(u)
newUser, err := user.CreateUser(s, u)
if err != nil {
_ = s.Rollback()
log.Fatalf("Error creating new user: %s", err)
}
err = models.CreateNewNamespaceForUser(newUser)
err = models.CreateNewNamespaceForUser(s, newUser)
if err != nil {
_ = s.Rollback()
log.Fatalf("Error creating new namespace for user: %s", err)
}
if err := s.Commit(); err != nil {
log.Fatalf("Error saving everything: %s", err)
}
fmt.Printf("\nUser was created successfully.\n")
},
}
@ -180,7 +199,10 @@ var userUpdateCmd = &cobra.Command{
initialize.FullInit()
},
Run: func(cmd *cobra.Command, args []string) {
u := getUserFromArg(args[0])
s := db.NewSession()
defer s.Close()
u := getUserFromArg(s, args[0])
if userFlagUsername != "" {
u.Username = userFlagUsername
@ -192,11 +214,16 @@ var userUpdateCmd = &cobra.Command{
u.AvatarProvider = userFlagAvatar
}
_, err := user.UpdateUser(u)
_, err := user.UpdateUser(s, u)
if err != nil {
_ = s.Rollback()
log.Fatalf("Error updating the user: %s", err)
}
if err := s.Commit(); err != nil {
log.Fatalf("Error saving everything: %s", err)
}
fmt.Println("User updated successfully.")
},
}
@ -209,22 +236,31 @@ var userResetPasswordCmd = &cobra.Command{
},
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
u := getUserFromArg(args[0])
s := db.NewSession()
defer s.Close()
u := getUserFromArg(s, args[0])
// By default we reset as usual, only with specific flag directly.
if userFlagResetPasswordDirectly {
err := user.UpdateUserPassword(u, getPasswordFromFlagOrInput())
err := user.UpdateUserPassword(s, u, getPasswordFromFlagOrInput())
if err != nil {
_ = s.Rollback()
log.Fatalf("Could not update user password: %s", err)
}
fmt.Println("Password updated successfully.")
} else {
err := user.RequestUserPasswordResetToken(u)
err := user.RequestUserPasswordResetToken(s, u)
if err != nil {
_ = s.Rollback()
log.Fatalf("Could not send password reset email: %s", err)
}
fmt.Println("Password reset email sent successfully.")
}
if err := s.Commit(); err != nil {
log.Fatalf("Could not send password reset email: %s", err)
}
},
}
@ -236,7 +272,10 @@ var userChangeEnabledCmd = &cobra.Command{
},
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
u := getUserFromArg(args[0])
s := db.NewSession()
defer s.Close()
u := getUserFromArg(s, args[0])
if userFlagEnableUser {
u.IsActive = true
@ -245,11 +284,16 @@ var userChangeEnabledCmd = &cobra.Command{
} else {
u.IsActive = !u.IsActive
}
_, err := user.UpdateUser(u)
_, err := user.UpdateUser(s, u)
if err != nil {
_ = s.Rollback()
log.Fatalf("Could not enable the user")
}
if err := s.Commit(); err != nil {
log.Fatalf("Error saving everything: %s", err)
}
fmt.Printf("User status successfully changed, user is now active: %t.\n", u.IsActive)
},
}

View File

@ -31,6 +31,7 @@ import (
"xorm.io/core"
"xorm.io/xorm"
"xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
_ "github.com/go-sql-driver/mysql" // Because.
_ "github.com/lib/pq" // Because.
@ -211,3 +212,13 @@ func WipeEverything() error {
return nil
}
// NewSession creates a new xorm session
func NewSession() *xorm.Session {
return x.NewSession()
}
// Type returns the db type of the currently configured db
func Type() schemas.DBType {
return x.Dialect().URI().DBType
}

View File

@ -22,6 +22,7 @@ import (
"time"
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/web"
"github.com/c2h5oh/datasize"
"github.com/spf13/afero"
@ -93,27 +94,44 @@ func CreateWithMime(f io.Reader, realname string, realsize uint64, a web.Auth, m
Mime: mime,
}
_, err = x.Insert(file)
s := db.NewSession()
defer s.Close()
_, err = s.Insert(file)
if err != nil {
_ = s.Rollback()
return
}
// Save the file to storage with its new ID as path
err = file.Save(f)
if err != nil {
_ = s.Rollback()
return
}
return
}
// Delete removes a file from the DB and the file system
func (f *File) Delete() (err error) {
deleted, err := x.Where("id = ?", f.ID).Delete(f)
s := db.NewSession()
defer s.Close()
deleted, err := s.Where("id = ?", f.ID).Delete(f)
if err != nil {
_ = s.Rollback()
return err
}
if deleted == 0 {
_ = s.Rollback()
return ErrFileDoesNotExist{FileID: f.ID}
}
err = afs.Remove(f.getFileName())
if err != nil {
_ = s.Rollback()
return err
}
return
}

View File

@ -19,6 +19,7 @@ package models
import (
"code.vikunja.io/web"
"github.com/imdario/mergo"
"xorm.io/xorm"
)
// BulkTask is the definition of a bulk update task
@ -29,9 +30,9 @@ type BulkTask struct {
Task
}
func (bt *BulkTask) checkIfTasksAreOnTheSameList() (err error) {
func (bt *BulkTask) checkIfTasksAreOnTheSameList(s *xorm.Session) (err error) {
// Get the tasks
err = bt.GetTasksByIDs()
err = bt.GetTasksByIDs(s)
if err != nil {
return err
}
@ -52,16 +53,16 @@ func (bt *BulkTask) checkIfTasksAreOnTheSameList() (err error) {
}
// CanUpdate checks if a user is allowed to update a task
func (bt *BulkTask) CanUpdate(a web.Auth) (bool, error) {
func (bt *BulkTask) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
err := bt.checkIfTasksAreOnTheSameList()
err := bt.checkIfTasksAreOnTheSameList(s)
if err != nil {
return false, err
}
// A user can update an task if he has write acces to its list
l := &List{ID: bt.Tasks[0].ListID}
return l.CanWrite(a)
return l.CanWrite(s, a)
}
// Update updates a bunch of tasks at once
@ -77,23 +78,14 @@ func (bt *BulkTask) CanUpdate(a web.Auth) (bool, error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its list)"
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/bulk [post]
func (bt *BulkTask) Update() (err error) {
sess := x.NewSession()
defer sess.Close()
err = sess.Begin()
if err != nil {
return
}
func (bt *BulkTask) Update(s *xorm.Session) (err error) {
for _, oldtask := range bt.Tasks {
// When a repeating task is marked as done, we update all deadlines and reminders and set it as undone
updateDone(oldtask, &bt.Task)
// Update the assignees
if err := oldtask.updateTaskAssignees(sess, bt.Assignees); err != nil {
if err := oldtask.updateTaskAssignees(s, bt.Assignees); err != nil {
return err
}
@ -109,7 +101,7 @@ func (bt *BulkTask) Update() (err error) {
oldtask.Done = false
}
_, err = sess.ID(oldtask.ID).
_, err = s.ID(oldtask.ID).
Cols("title",
"description",
"done",
@ -121,15 +113,9 @@ func (bt *BulkTask) Update() (err error) {
"end_date").
Update(oldtask)
if err != nil {
_ = sess.Rollback()
return err
}
}
err = sess.Commit()
if err != nil {
return
}
return
}

View File

@ -57,18 +57,22 @@ func TestBulkTask_Update(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
bt := &BulkTask{
IDs: tt.fields.IDs,
Tasks: tt.fields.Tasks,
Task: tt.fields.Task,
}
allowed, _ := bt.CanUpdate(tt.fields.User)
allowed, _ := bt.CanUpdate(s, tt.fields.User)
if !allowed != tt.wantForbidden {
t.Errorf("BulkTask.Update() want forbidden, got %v, want %v", allowed, tt.wantForbidden)
}
if err := bt.Update(); (err != nil) != tt.wantErr {
if err := bt.Update(s); (err != nil) != tt.wantErr {
t.Errorf("BulkTask.Update() error = %v, wantErr %v", err, tt.wantErr)
}
s.Close()
})
}
}

View File

@ -97,14 +97,14 @@ func getDefaultBucket(s *xorm.Session, listID int64) (bucket *Bucket, err error)
// @Success 200 {array} models.Bucket "The buckets with their tasks"
// @Failure 500 {object} models.Message "Internal server error"
// @Router /lists/{id}/buckets [get]
func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (b *Bucket) ReadAll(s *xorm.Session, auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Note: I'm ignoring pagination for now since I've yet to figure out a way on how to make it work
// I'll probably just don't do it and instead make individual tasks archivable.
// Get all buckets for this list
buckets := []*Bucket{}
err = x.Where("list_id = ?", b.ListID).Find(&buckets)
err = s.Where("list_id = ?", b.ListID).Find(&buckets)
if err != nil {
return
}
@ -119,7 +119,7 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r
// Get all users
users := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&users)
err = s.In("id", userIDs).Find(&users)
if err != nil {
return
}
@ -132,7 +132,7 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r
b.TaskCollection.ListID = b.ListID
b.TaskCollection.OrderBy = []string{string(orderAscending)}
b.TaskCollection.SortBy = []string{taskPropertyPosition}
ts, _, _, err := b.TaskCollection.ReadAll(auth, "", -1, 0)
ts, _, _, err := b.TaskCollection.ReadAll(s, auth, "", -1, 0)
if err != nil {
return
}
@ -168,10 +168,10 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r
// @Failure 404 {object} web.HTTPError "The list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/buckets [put]
func (b *Bucket) Create(a web.Auth) (err error) {
func (b *Bucket) Create(s *xorm.Session, a web.Auth) (err error) {
b.CreatedByID = a.GetID()
_, err = x.Insert(b)
_, err = s.Insert(b)
return
}
@ -190,8 +190,8 @@ func (b *Bucket) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "The bucket does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/buckets/{bucketID} [post]
func (b *Bucket) Update() (err error) {
_, err = x.Where("id = ?", b.ID).Update(b)
func (b *Bucket) Update(s *xorm.Session) (err error) {
_, err = s.Where("id = ?", b.ID).Update(b)
return
}
@ -208,14 +208,11 @@ func (b *Bucket) Update() (err error) {
// @Failure 404 {object} web.HTTPError "The bucket does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/buckets/{bucketID} [delete]
func (b *Bucket) Delete() (err error) {
s := x.NewSession()
func (b *Bucket) Delete(s *xorm.Session) (err error) {
// Prevent removing the last bucket
total, err := s.Where("list_id = ?", b.ListID).Count(&Bucket{})
if err != nil {
_ = s.Rollback()
return
}
if total <= 1 {
@ -228,23 +225,19 @@ func (b *Bucket) Delete() (err error) {
// Remove the bucket itself
_, err = s.Where("id = ?", b.ID).Delete(&Bucket{})
if err != nil {
_ = s.Rollback()
return
}
// Get the default bucket
defaultBucket, err := getDefaultBucket(s, b.ListID)
if err != nil {
_ = s.Rollback()
return
}
// Remove all associations of tasks to that bucket
_, err = s.Where("bucket_id = ?", b.ID).Cols("bucket_id").Update(&Task{BucketID: defaultBucket.ID})
if err != nil {
_ = s.Rollback()
return
}
return s.Commit()
_, err = s.
Where("bucket_id = ?", b.ID).
Cols("bucket_id").
Update(&Task{BucketID: defaultBucket.ID})
return
}

View File

@ -16,30 +16,33 @@
package models
import "code.vikunja.io/web"
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if a user can create a new bucket
func (b *Bucket) CanCreate(a web.Auth) (bool, error) {
func (b *Bucket) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
l := &List{ID: b.ListID}
return l.CanWrite(a)
return l.CanWrite(s, a)
}
// CanUpdate checks if a user can update an existing bucket
func (b *Bucket) CanUpdate(a web.Auth) (bool, error) {
return b.canDoBucket(a)
func (b *Bucket) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return b.canDoBucket(s, a)
}
// CanDelete checks if a user can delete an existing bucket
func (b *Bucket) CanDelete(a web.Auth) (bool, error) {
return b.canDoBucket(a)
func (b *Bucket) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return b.canDoBucket(s, a)
}
// canDoBucket checks if the bucket exists and if the user has the right to act on it
func (b *Bucket) canDoBucket(a web.Auth) (bool, error) {
bb, err := getBucketByID(x.NewSession(), b.ID)
func (b *Bucket) canDoBucket(s *xorm.Session, a web.Auth) (bool, error) {
bb, err := getBucketByID(s, b.ID)
if err != nil {
return false, err
}
l := &List{ID: bb.ListID}
return l.CanWrite(a)
return l.CanWrite(s, a)
}

View File

@ -27,10 +27,12 @@ import (
func TestBucket_ReadAll(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
testuser := &user.User{ID: 1}
b := &Bucket{ListID: 1}
bucketsInterface, _, _, err := b.ReadAll(testuser, "", 0, 0)
bucketsInterface, _, _, err := b.ReadAll(s, testuser, "", 0, 0)
assert.NoError(t, err)
buckets, is := bucketsInterface.([]*Bucket)
@ -66,6 +68,8 @@ func TestBucket_ReadAll(t *testing.T) {
})
t.Run("filtered", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
testuser := &user.User{ID: 1}
b := &Bucket{
@ -76,7 +80,7 @@ func TestBucket_ReadAll(t *testing.T) {
FilterValue: []string{"done"},
},
}
bucketsInterface, _, _, err := b.ReadAll(testuser, "", 0, 0)
bucketsInterface, _, _, err := b.ReadAll(s, testuser, "", 0, 0)
assert.NoError(t, err)
buckets := bucketsInterface.([]*Bucket)
@ -88,16 +92,21 @@ func TestBucket_ReadAll(t *testing.T) {
func TestBucket_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
b := &Bucket{
ID: 2, // The second bucket only has 3 tasks
ListID: 1,
}
err := b.Delete()
err := b.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
// Assert all tasks have been moved to bucket 1 as that one is the first
tasks := []*Task{}
err = x.Where("bucket_id = ?", 1).Find(&tasks)
err = s.Where("bucket_id = ?", 1).Find(&tasks)
assert.NoError(t, err)
assert.Len(t, tasks, 15)
db.AssertMissing(t, "buckets", map[string]interface{}{
@ -107,13 +116,19 @@ func TestBucket_Delete(t *testing.T) {
})
t.Run("last bucket in list", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
b := &Bucket{
ID: 34,
ListID: 18,
}
err := b.Delete()
err := b.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrCannotRemoveLastBucket(err))
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "buckets", map[string]interface{}{
"id": 34,
"list_id": 18,

View File

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// Label represents a label
@ -64,7 +65,7 @@ func (Label) TableName() string {
// @Failure 400 {object} web.HTTPError "Invalid label object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /labels [put]
func (l *Label) Create(a web.Auth) (err error) {
func (l *Label) Create(s *xorm.Session, a web.Auth) (err error) {
u, err := user.GetFromAuth(a)
if err != nil {
return
@ -73,7 +74,7 @@ func (l *Label) Create(a web.Auth) (err error) {
l.CreatedBy = u
l.CreatedByID = u.ID
_, err = x.Insert(l)
_, err = s.Insert(l)
return
}
@ -92,8 +93,8 @@ func (l *Label) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "Label not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /labels/{id} [put]
func (l *Label) Update() (err error) {
_, err = x.
func (l *Label) Update(s *xorm.Session) (err error) {
_, err = s.
ID(l.ID).
Cols(
"title",
@ -105,7 +106,7 @@ func (l *Label) Update() (err error) {
return
}
err = l.ReadOne()
err = l.ReadOne(s)
return
}
@ -122,8 +123,8 @@ func (l *Label) Update() (err error) {
// @Failure 404 {object} web.HTTPError "Label not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /labels/{id} [delete]
func (l *Label) Delete() (err error) {
_, err = x.ID(l.ID).Delete(&Label{})
func (l *Label) Delete(s *xorm.Session) (err error) {
_, err = s.ID(l.ID).Delete(&Label{})
return err
}
@ -140,7 +141,7 @@ func (l *Label) Delete() (err error) {
// @Success 200 {array} models.Label "The labels"
// @Failure 500 {object} models.Message "Internal error"
// @Router /labels [get]
func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls interface{}, resultCount int, numberOfEntries int64, err error) {
func (l *Label) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (ls interface{}, resultCount int, numberOfEntries int64, err error) {
if _, is := a.(*LinkSharing); is {
return nil, 0, 0, ErrGenericForbidden{}
}
@ -148,12 +149,12 @@ func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls in
u := &user.User{ID: a.GetID()}
// Get all tasks
taskIDs, err := getUserTaskIDs(u)
taskIDs, err := getUserTaskIDs(s, u)
if err != nil {
return nil, 0, 0, err
}
return getLabelsByTaskIDs(&LabelByTaskIDsOptions{
return getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
Search: search,
User: u,
TaskIDs: taskIDs,
@ -177,25 +178,25 @@ func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls in
// @Failure 404 {object} web.HTTPError "Label not found"
// @Failure 500 {object} models.Message "Internal error"
// @Router /labels/{id} [get]
func (l *Label) ReadOne() (err error) {
label, err := getLabelByIDSimple(l.ID)
func (l *Label) ReadOne(s *xorm.Session) (err error) {
label, err := getLabelByIDSimple(s, l.ID)
if err != nil {
return err
}
*l = *label
user, err := user.GetUserByID(l.CreatedByID)
u, err := user.GetUserByID(s, l.CreatedByID)
if err != nil {
return err
}
l.CreatedBy = user
l.CreatedBy = u
return
}
func getLabelByIDSimple(labelID int64) (*Label, error) {
func getLabelByIDSimple(s *xorm.Session, labelID int64) (*Label, error) {
label := Label{}
exists, err := x.ID(labelID).Get(&label)
exists, err := s.ID(labelID).Get(&label)
if err != nil {
return &label, err
}
@ -207,18 +208,21 @@ func getLabelByIDSimple(labelID int64) (*Label, error) {
}
// Helper method to get all task ids a user has
func getUserTaskIDs(u *user.User) (taskIDs []int64, err error) {
func getUserTaskIDs(s *xorm.Session, u *user.User) (taskIDs []int64, err error) {
// Get all lists
lists, _, _, err := getRawListsForUser(&listOptions{
user: u,
page: -1,
})
lists, _, _, err := getRawListsForUser(
s,
&listOptions{
user: u,
page: -1,
},
)
if err != nil {
return nil, err
}
tasks, _, _, err := getRawTasksForLists(lists, u, &taskOptions{
tasks, _, _, err := getRawTasksForLists(s, lists, u, &taskOptions{
page: -1,
perPage: 0,
})

View File

@ -20,26 +20,27 @@ import (
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/builder"
"xorm.io/xorm"
)
// CanUpdate checks if a user can update a label
func (l *Label) CanUpdate(a web.Auth) (bool, error) {
return l.isLabelOwner(a) // Only owners should be allowed to update a label
func (l *Label) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return l.isLabelOwner(s, a) // Only owners should be allowed to update a label
}
// CanDelete checks if a user can delete a label
func (l *Label) CanDelete(a web.Auth) (bool, error) {
return l.isLabelOwner(a) // Only owners should be allowed to delete a label
func (l *Label) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return l.isLabelOwner(s, a) // Only owners should be allowed to delete a label
}
// CanRead checks if a user can read a label
func (l *Label) CanRead(a web.Auth) (bool, int, error) {
return l.hasAccessToLabel(a)
func (l *Label) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
return l.hasAccessToLabel(s, a)
}
// CanCreate checks if the user can create a label
// Currently a dummy.
func (l *Label) CanCreate(a web.Auth) (bool, error) {
func (l *Label) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is {
return false, nil
}
@ -47,13 +48,13 @@ func (l *Label) CanCreate(a web.Auth) (bool, error) {
return true, nil
}
func (l *Label) isLabelOwner(a web.Auth) (bool, error) {
func (l *Label) isLabelOwner(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is {
return false, nil
}
lorig, err := getLabelByIDSimple(l.ID)
lorig, err := getLabelByIDSimple(s, l.ID)
if err != nil {
return false, err
}
@ -61,19 +62,19 @@ func (l *Label) isLabelOwner(a web.Auth) (bool, error) {
}
// Helper method to check if a user can see a specific label
func (l *Label) hasAccessToLabel(a web.Auth) (has bool, maxRight int, err error) {
func (l *Label) hasAccessToLabel(s *xorm.Session, a web.Auth) (has bool, maxRight int, err error) {
// TODO: add an extra check for link share handling
// Get all tasks
taskIDs, err := getUserTaskIDs(&user.User{ID: a.GetID()})
taskIDs, err := getUserTaskIDs(s, &user.User{ID: a.GetID()})
if err != nil {
return false, 0, err
}
// Get all labels associated with these tasks
ll := &LabelTask{}
has, err = x.Table("labels").
has, err = s.Table("labels").
Select("label_task.*").
Join("LEFT", "label_task", "label_task.label_id = labels.id").
Where("label_task.label_id is not null OR labels.created_by_id = ?", a.GetID()).
@ -87,7 +88,7 @@ func (l *Label) hasAccessToLabel(a web.Auth) (has bool, maxRight int, err error)
// Since the right depends on the task the label is associated with, we need to check that too.
if ll.TaskID > 0 {
t := &Task{ID: ll.TaskID}
_, maxRight, err = t.CanRead(a)
_, maxRight, err = t.CanRead(s, a)
if err != nil {
return
}

View File

@ -22,10 +22,10 @@ import (
"time"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/builder"
"xorm.io/xorm"
)
// LabelTask represents a relation between a label and a task
@ -61,8 +61,8 @@ func (LabelTask) TableName() string {
// @Failure 404 {object} web.HTTPError "Label not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{task}/labels/{label} [delete]
func (lt *LabelTask) Delete() (err error) {
_, err = x.Delete(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
func (lt *LabelTask) Delete(s *xorm.Session) (err error) {
_, err = s.Delete(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
return err
}
@ -81,9 +81,9 @@ func (lt *LabelTask) Delete() (err error) {
// @Failure 404 {object} web.HTTPError "The label does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{task}/labels [put]
func (lt *LabelTask) Create(a web.Auth) (err error) {
func (lt *LabelTask) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the label is already added
exists, err := x.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
exists, err := s.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
if err != nil {
return err
}
@ -92,12 +92,12 @@ func (lt *LabelTask) Create(a web.Auth) (err error) {
}
// Insert it
_, err = x.Insert(lt)
_, err = s.Insert(lt)
if err != nil {
return err
}
err = updateListByTaskID(lt.TaskID)
err = updateListByTaskID(s, lt.TaskID)
return
}
@ -115,10 +115,10 @@ func (lt *LabelTask) Create(a web.Auth) (err error) {
// @Success 200 {array} models.Label "The labels"
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{task}/labels [get]
func (lt *LabelTask) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (lt *LabelTask) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has the right to see the task
task := Task{ID: lt.TaskID}
canRead, _, err := task.CanRead(a)
canRead, _, err := task.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -126,7 +126,7 @@ func (lt *LabelTask) ReadAll(a web.Auth, search string, page int, perPage int) (
return nil, 0, 0, ErrNoRightToSeeTask{lt.TaskID, a.GetID()}
}
return getLabelsByTaskIDs(&LabelByTaskIDsOptions{
return getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
User: &user.User{ID: a.GetID()},
Search: search,
Page: page,
@ -153,7 +153,7 @@ type LabelByTaskIDsOptions struct {
// Helper function to get all labels for a set of tasks
// Used when getting all labels for one task as well when getting all lables
func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, resultCount int, totalEntries int64, err error) {
func getLabelsByTaskIDs(s *xorm.Session, opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, resultCount int, totalEntries int64, err error) {
// We still need the task ID when we want to get all labels for a task, but because of this, we get the same label
// multiple times when it is associated to more than one task.
// Because of this whole thing, we need this extra switch here to only group by Task IDs if needed.
@ -194,7 +194,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
limit, start := getLimitFromPageIndex(opts.Page, opts.PerPage)
query := x.Table("labels").
query := s.Table("labels").
Select(selectStmt).
Join("LEFT", "label_task", "label_task.label_id = labels.id").
Where(cond).
@ -214,7 +214,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
userids = append(userids, l.CreatedByID)
}
users := make(map[int64]*user.User)
err = x.In("id", userids).Find(&users)
err = s.In("id", userids).Find(&users)
if err != nil {
return nil, 0, 0, err
}
@ -230,7 +230,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
}
// Get the total number of entries
totalEntries, err = x.Table("labels").
totalEntries, err = s.Table("labels").
Select("count(DISTINCT labels.id)").
Join("LEFT", "label_task", "label_task.label_id = labels.id").
Where(cond).
@ -244,11 +244,11 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
}
// Create or update a bunch of task labels
func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) {
func (t *Task) updateTaskLabels(s *xorm.Session, creator web.Auth, labels []*Label) (err error) {
// If we don't have any new labels, delete everything right away. Saves us some hassle.
if len(labels) == 0 && len(t.Labels) > 0 {
_, err = x.Where("task_id = ?", t.ID).
_, err = s.Where("task_id = ?", t.ID).
Delete(LabelTask{})
return err
}
@ -289,7 +289,7 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) {
// Delete all labels not passed
if len(labelsToDelete) > 0 {
_, err = x.In("label_id", labelsToDelete).
_, err = s.In("label_id", labelsToDelete).
And("task_id = ?", t.ID).
Delete(LabelTask{})
if err != nil {
@ -306,13 +306,13 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) {
}
// Add the new label
label, err := getLabelByIDSimple(l.ID)
label, err := getLabelByIDSimple(s, l.ID)
if err != nil {
return err
}
// Check if the user has the rights to see the label he is about to add
hasAccessToLabel, _, err := label.hasAccessToLabel(creator)
hasAccessToLabel, _, err := label.hasAccessToLabel(s, creator)
if err != nil {
return err
}
@ -322,14 +322,14 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) {
}
// Insert it
_, err = x.Insert(&LabelTask{LabelID: l.ID, TaskID: t.ID})
_, err = s.Insert(&LabelTask{LabelID: l.ID, TaskID: t.ID})
if err != nil {
return err
}
t.Labels = append(t.Labels, label)
}
err = updateListLastUpdated(&List{ID: t.ListID})
err = updateListLastUpdated(s, &List{ID: t.ListID})
return
}
@ -356,12 +356,12 @@ type LabelTaskBulk struct {
// @Failure 400 {object} web.HTTPError "Invalid label object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/labels/bulk [post]
func (ltb *LabelTaskBulk) Create(a web.Auth) (err error) {
task, err := GetTaskByIDSimple(ltb.TaskID)
func (ltb *LabelTaskBulk) Create(s *xorm.Session, a web.Auth) (err error) {
task, err := GetTaskByIDSimple(s, ltb.TaskID)
if err != nil {
return
}
labels, _, _, err := getLabelsByTaskIDs(&LabelByTaskIDsOptions{
labels, _, _, err := getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
TaskIDs: []int64{ltb.TaskID},
})
if err != nil {
@ -370,5 +370,5 @@ func (ltb *LabelTaskBulk) Create(a web.Auth) (err error) {
for _, l := range labels {
task.Labels = append(task.Labels, &l.Label)
}
return task.updateTaskLabels(a, ltb.Labels)
return task.updateTaskLabels(s, a, ltb.Labels)
}

View File

@ -18,21 +18,22 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if a user can add a label to a task
func (lt *LabelTask) CanCreate(a web.Auth) (bool, error) {
label, err := getLabelByIDSimple(lt.LabelID)
func (lt *LabelTask) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
label, err := getLabelByIDSimple(s, lt.LabelID)
if err != nil {
return false, err
}
hasAccessTolabel, _, err := label.hasAccessToLabel(a)
hasAccessTolabel, _, err := label.hasAccessToLabel(s, a)
if err != nil || !hasAccessTolabel { // If the user doesn't have access to the label, we can error out here
return false, err
}
canDoLabelTask, err := canDoLabelTask(lt.TaskID, a)
canDoLabelTask, err := canDoLabelTask(s, lt.TaskID, a)
if err != nil {
return false, err
}
@ -41,8 +42,8 @@ func (lt *LabelTask) CanCreate(a web.Auth) (bool, error) {
}
// CanDelete checks if a user can delete a label from a task
func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) {
canDoLabelTask, err := canDoLabelTask(lt.TaskID, a)
func (lt *LabelTask) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
canDoLabelTask, err := canDoLabelTask(s, lt.TaskID, a)
if err != nil {
return false, err
}
@ -52,7 +53,7 @@ func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) {
// We don't care here if the label exists or not. The only relevant thing here is if the relation already exists,
// throw an error.
exists, err := x.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
exists, err := s.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
if err != nil {
return false, err
}
@ -60,18 +61,18 @@ func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) {
}
// CanCreate determines if a user can update a labeltask
func (ltb *LabelTaskBulk) CanCreate(a web.Auth) (bool, error) {
return canDoLabelTask(ltb.TaskID, a)
func (ltb *LabelTaskBulk) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return canDoLabelTask(s, ltb.TaskID, a)
}
// Helper function to check if a user can write to a task
// + is able to see the label
// always the same check for either deleting or adding a label to a task
func canDoLabelTask(taskID int64, a web.Auth) (bool, error) {
func canDoLabelTask(s *xorm.Session, taskID int64, a web.Auth) (bool, error) {
// A user can add a label to a task if he can write to the task
task, err := GetTaskByIDSimple(taskID)
task, err := GetTaskByIDSimple(s, taskID)
if err != nil {
return false, err
}
return task.CanUpdate(a)
return task.CanUpdate(s, a)
}

View File

@ -91,6 +91,7 @@ func TestLabelTask_ReadAll(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
l := &LabelTask{
ID: tt.fields.ID,
@ -100,7 +101,7 @@ func TestLabelTask_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
gotLabels, _, _, err := l.ReadAll(tt.args.a, tt.args.search, tt.args.page, 0)
gotLabels, _, _, err := l.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 0)
if (err != nil) != tt.wantErr {
t.Errorf("LabelTask.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
return
@ -111,6 +112,8 @@ func TestLabelTask_ReadAll(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(gotLabels, tt.wantLabels); !equal {
t.Errorf("LabelTask.ReadAll() = %v, want %v, diff: %v", l, tt.wantLabels, diff)
}
s.Close()
})
}
}
@ -186,6 +189,8 @@ func TestLabelTask_Create(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
l := &LabelTask{
ID: tt.fields.ID,
TaskID: tt.fields.TaskID,
@ -194,11 +199,11 @@ func TestLabelTask_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
allowed, _ := l.CanCreate(tt.args.a)
allowed, _ := l.CanCreate(s, tt.args.a)
if !allowed && !tt.wantForbidden {
t.Errorf("LabelTask.CanCreate() forbidden, want %v", tt.wantForbidden)
}
err := l.Create(tt.args.a)
err := l.Create(s, tt.args.a)
if (err != nil) != tt.wantErr {
t.Errorf("LabelTask.Create() error = %v, wantErr %v", err, tt.wantErr)
}
@ -212,6 +217,7 @@ func TestLabelTask_Create(t *testing.T) {
"label_id": l.LabelID,
}, false)
}
s.Close()
})
}
}
@ -282,6 +288,8 @@ func TestLabelTask_Delete(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
l := &LabelTask{
ID: tt.fields.ID,
TaskID: tt.fields.TaskID,
@ -290,11 +298,11 @@ func TestLabelTask_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
allowed, _ := l.CanDelete(tt.auth)
allowed, _ := l.CanDelete(s, tt.auth)
if !allowed && !tt.wantForbidden {
t.Errorf("LabelTask.CanDelete() forbidden, want %v", tt.wantForbidden)
}
err := l.Delete()
err := l.Delete(s)
if (err != nil) != tt.wantErr {
t.Errorf("LabelTask.Delete() error = %v, wantErr %v", err, tt.wantErr)
}
@ -307,6 +315,7 @@ func TestLabelTask_Delete(t *testing.T) {
"task_id": l.TaskID,
})
}
s.Close()
})
}
}

View File

@ -133,7 +133,8 @@ func TestLabel_ReadAll(t *testing.T) {
Rights: tt.fields.Rights,
}
db.LoadAndAssertFixtures(t)
gotLs, _, _, err := l.ReadAll(tt.args.a, tt.args.search, tt.args.page, 0)
s := db.NewSession()
gotLs, _, _, err := l.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 0)
if (err != nil) != tt.wantErr {
t.Errorf("Label.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
return
@ -141,6 +142,7 @@ func TestLabel_ReadAll(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(gotLs, tt.wantLs); !equal {
t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", gotLs, tt.wantLs, diff)
}
s.Close()
})
}
}
@ -249,11 +251,13 @@ func TestLabel_ReadOne(t *testing.T) {
Rights: tt.fields.Rights,
}
allowed, _, _ := l.CanRead(tt.auth)
s := db.NewSession()
allowed, _, _ := l.CanRead(s, tt.auth)
if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanRead() forbidden, want %v", tt.wantForbidden)
}
err := l.ReadOne()
err := l.ReadOne(s)
if (err != nil) != tt.wantErr {
t.Errorf("Label.ReadOne() error = %v, wantErr %v", err, tt.wantErr)
}
@ -263,6 +267,8 @@ func TestLabel_ReadOne(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(l, tt.want); !equal && !tt.wantErr && !tt.wantForbidden {
t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", l, tt.want, diff)
}
s.Close()
})
}
}
@ -316,11 +322,12 @@ func TestLabel_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
allowed, _ := l.CanCreate(tt.args.a)
s := db.NewSession()
allowed, _ := l.CanCreate(s, tt.args.a)
if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanCreate() forbidden, want %v", tt.wantForbidden)
}
if err := l.Create(tt.args.a); (err != nil) != tt.wantErr {
if err := l.Create(s, tt.args.a); (err != nil) != tt.wantErr {
t.Errorf("Label.Create() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
@ -331,6 +338,7 @@ func TestLabel_Create(t *testing.T) {
"hex_color": l.HexColor,
}, false)
}
_ = s.Close()
})
}
}
@ -406,11 +414,12 @@ func TestLabel_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
allowed, _ := l.CanUpdate(tt.auth)
s := db.NewSession()
allowed, _ := l.CanUpdate(s, tt.auth)
if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanUpdate() forbidden, want %v", tt.wantForbidden)
}
if err := l.Update(); (err != nil) != tt.wantErr {
if err := l.Update(s); (err != nil) != tt.wantErr {
t.Errorf("Label.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && !tt.wantForbidden {
@ -419,6 +428,7 @@ func TestLabel_Update(t *testing.T) {
"title": tt.fields.Title,
}, false)
}
_ = s.Close()
})
}
}
@ -490,11 +500,12 @@ func TestLabel_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
allowed, _ := l.CanDelete(tt.auth)
s := db.NewSession()
allowed, _ := l.CanDelete(s, tt.auth)
if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanDelete() forbidden, want %v", tt.wantForbidden)
}
if err := l.Delete(); (err != nil) != tt.wantErr {
if err := l.Delete(s); (err != nil) != tt.wantErr {
t.Errorf("Label.Delete() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && !tt.wantForbidden {
@ -502,6 +513,7 @@ func TestLabel_Delete(t *testing.T) {
"id": l.ID,
})
}
_ = s.Close()
})
}
}

View File

@ -24,6 +24,7 @@ import (
"code.vikunja.io/api/pkg/utils"
"code.vikunja.io/web"
"github.com/dgrijalva/jwt-go"
"xorm.io/xorm"
)
// SharingType holds the sharing type
@ -99,7 +100,7 @@ func GetLinkShareFromClaims(claims jwt.MapClaims) (share *LinkSharing, err error
// @Failure 404 {object} web.HTTPError "The list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares [put]
func (share *LinkSharing) Create(a web.Auth) (err error) {
func (share *LinkSharing) Create(s *xorm.Session, a web.Auth) (err error) {
err = share.Right.isValid()
if err != nil {
@ -108,7 +109,7 @@ func (share *LinkSharing) Create(a web.Auth) (err error) {
share.SharedByID = a.GetID()
share.Hash = utils.MakeRandomString(40)
_, err = x.Insert(share)
_, err = s.Insert(share)
share.SharedBy, _ = user.GetFromAuth(a)
return
}
@ -127,8 +128,8 @@ func (share *LinkSharing) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "Share Link not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares/{share} [get]
func (share *LinkSharing) ReadOne() (err error) {
exists, err := x.Where("id = ?", share.ID).Get(share)
func (share *LinkSharing) ReadOne(s *xorm.Session) (err error) {
exists, err := s.Where("id = ?", share.ID).Get(share)
if err != nil {
return err
}
@ -152,9 +153,9 @@ func (share *LinkSharing) ReadOne() (err error) {
// @Success 200 {array} models.LinkSharing "The share links"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares [get]
func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
func (share *LinkSharing) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
list := &List{ID: share.ListID}
can, _, err := list.CanRead(a)
can, _, err := list.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -165,7 +166,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
limit, start := getLimitFromPageIndex(page, perPage)
var shares []*LinkSharing
query := x.
query := s.
Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%")
if limit > 0 {
query = query.Limit(limit, start)
@ -182,7 +183,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
}
users := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&users)
err = s.In("id", userIDs).Find(&users)
if err != nil {
return nil, 0, 0, err
}
@ -192,7 +193,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
}
// Total count
totalItems, err = x.
totalItems, err = s.
Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%").
Count(&LinkSharing{})
if err != nil {
@ -216,15 +217,15 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
// @Failure 404 {object} web.HTTPError "Share Link not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares/{share} [delete]
func (share *LinkSharing) Delete() (err error) {
_, err = x.Where("id = ?", share.ID).Delete(share)
func (share *LinkSharing) Delete(s *xorm.Session) (err error) {
_, err = s.Where("id = ?", share.ID).Delete(share)
return
}
// GetLinkShareByHash returns a link share by hash
func GetLinkShareByHash(hash string) (share *LinkSharing, err error) {
func GetLinkShareByHash(s *xorm.Session, hash string) (share *LinkSharing, err error) {
share = &LinkSharing{}
has, err := x.Where("hash = ?", hash).Get(share)
has, err := s.Where("hash = ?", hash).Get(share)
if err != nil {
return
}
@ -235,13 +236,12 @@ func GetLinkShareByHash(hash string) (share *LinkSharing, err error) {
}
// GetListByShareHash returns a link share by its hash
func GetListByShareHash(hash string) (list *List, err error) {
share, err := GetLinkShareByHash(hash)
func GetListByShareHash(s *xorm.Session, hash string) (list *List, err error) {
share, err := GetLinkShareByHash(s, hash)
if err != nil {
return
}
list = &List{ID: share.ListID}
err = list.GetSimpleByID()
list, err = GetListSimpleByID(s, share.ListID)
return
}

View File

@ -16,53 +16,55 @@
package models
import "code.vikunja.io/web"
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead implements the read right check for a link share
func (share *LinkSharing) CanRead(a web.Auth) (bool, int, error) {
func (share *LinkSharing) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
// Don't allow creating link shares if the user itself authenticated with a link share
if _, is := a.(*LinkSharing); is {
return false, 0, nil
}
l, err := GetListByShareHash(share.Hash)
l, err := GetListByShareHash(s, share.Hash)
if err != nil {
return false, 0, err
}
return l.CanRead(a)
return l.CanRead(s, a)
}
// CanDelete implements the delete right check for a link share
func (share *LinkSharing) CanDelete(a web.Auth) (bool, error) {
return share.canDoLinkShare(a)
func (share *LinkSharing) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return share.canDoLinkShare(s, a)
}
// CanUpdate implements the update right check for a link share
func (share *LinkSharing) CanUpdate(a web.Auth) (bool, error) {
return share.canDoLinkShare(a)
func (share *LinkSharing) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return share.canDoLinkShare(s, a)
}
// CanCreate implements the create right check for a link share
func (share *LinkSharing) CanCreate(a web.Auth) (bool, error) {
return share.canDoLinkShare(a)
func (share *LinkSharing) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return share.canDoLinkShare(s, a)
}
func (share *LinkSharing) canDoLinkShare(a web.Auth) (bool, error) {
func (share *LinkSharing) canDoLinkShare(s *xorm.Session, a web.Auth) (bool, error) {
// Don't allow creating link shares if the user itself authenticated with a link share
if _, is := a.(*LinkSharing); is {
return false, nil
}
l := &List{ID: share.ListID}
err := l.GetSimpleByID()
l, err := GetListSimpleByID(s, share.ListID)
if err != nil {
return false, err
}
// Check if the user is admin when the link right is admin
if share.Right == RightAdmin {
return l.IsAdmin(a)
return l.IsAdmin(s, a)
}
return l.CanWrite(a)
return l.CanWrite(s, a)
}

View File

@ -96,9 +96,9 @@ var FavoritesPseudoList = List{
}
// GetListsByNamespaceID gets all lists in a namespace
func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error) {
func GetListsByNamespaceID(s *xorm.Session, nID int64, doer *user.User) (lists []*List, err error) {
if nID == -1 {
err = x.Select("l.*").
err = s.Select("l.*").
Table("list").
Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id").
Join("LEFT", []string{"team_members", "tm"}, "tm.team_id = tl.team_id").
@ -111,7 +111,7 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error
GroupBy("l.id").
Find(&lists)
} else {
err = x.Select("l.*").
err = s.Select("l.*").
Alias("l").
Join("LEFT", []string{"namespaces", "n"}, "l.namespace_id = n.id").
Where("l.is_archived = false").
@ -124,7 +124,7 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error
}
// get more list details
err = AddListDetails(lists)
err = addListDetails(s, lists)
return lists, err
}
@ -143,33 +143,34 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists [get]
func (l *List) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
func (l *List) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
// Check if we're dealing with a share auth
shareAuth, ok := a.(*LinkSharing)
if ok {
list := &List{ID: shareAuth.ListID}
err := list.GetSimpleByID()
list, err := GetListSimpleByID(s, shareAuth.ListID)
if err != nil {
return nil, 0, 0, err
}
lists := []*List{list}
err = AddListDetails(lists)
err = addListDetails(s, lists)
return lists, 0, 0, err
}
lists, resultCount, totalItems, err := getRawListsForUser(&listOptions{
search: search,
user: &user.User{ID: a.GetID()},
page: page,
perPage: perPage,
isArchived: l.IsArchived,
})
lists, resultCount, totalItems, err := getRawListsForUser(
s,
&listOptions{
search: search,
user: &user.User{ID: a.GetID()},
page: page,
perPage: perPage,
isArchived: l.IsArchived,
})
if err != nil {
return nil, 0, 0, err
}
// Add more list details
err = AddListDetails(lists)
err = addListDetails(s, lists)
return lists, resultCount, totalItems, err
}
@ -185,7 +186,7 @@ func (l *List) ReadAll(a web.Auth, search string, page int, perPage int) (result
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [get]
func (l *List) ReadOne() (err error) {
func (l *List) ReadOne(s *xorm.Session) (err error) {
if l.ID == FavoritesPseudoList.ID {
// Already "built" the list in CanRead
@ -194,7 +195,7 @@ func (l *List) ReadOne() (err error) {
// Check for saved filters
if getSavedFilterIDFromListID(l.ID) > 0 {
sf, err := getSavedFilterSimpleByID(getSavedFilterIDFromListID(l.ID))
sf, err := getSavedFilterSimpleByID(s, getSavedFilterIDFromListID(l.ID))
if err != nil {
return err
}
@ -206,13 +207,13 @@ func (l *List) ReadOne() (err error) {
}
// Get list owner
l.Owner, err = user.GetUserByID(l.OwnerID)
l.Owner, err = user.GetUserByID(s, l.OwnerID)
if err != nil {
return err
}
// Check if the namespace is archived and set the namespace to archived if it is not already archived individually.
if !l.IsArchived {
err = l.CheckIsArchived()
err = l.CheckIsArchived(s)
if err != nil {
if !IsErrNamespaceIsArchived(err) && !IsErrListIsArchived(err) {
return
@ -224,7 +225,7 @@ func (l *List) ReadOne() (err error) {
// Get any background information if there is one set
if l.BackgroundFileID != 0 {
// Unsplash image
l.BackgroundInformation, err = GetUnsplashPhotoByFileID(l.BackgroundFileID)
l.BackgroundInformation, err = GetUnsplashPhotoByFileID(s, l.BackgroundFileID)
if err != nil && !files.IsErrFileIsNotUnsplashFile(err) {
return
}
@ -237,44 +238,33 @@ func (l *List) ReadOne() (err error) {
return nil
}
// GetSimpleByID gets a list with only the basic items, aka no tasks or user objects. Returns an error if the list does not exist.
func (l *List) GetSimpleByID() (err error) {
s := x.NewSession()
err = l.getSimpleByID(s)
if err != nil {
_ = s.Rollback()
return err
}
return nil
}
// GetListSimpleByID gets a list with only the basic items, aka no tasks or user objects. Returns an error if the list does not exist.
func GetListSimpleByID(s *xorm.Session, listID int64) (list *List, err error) {
func (l *List) getSimpleByID(s *xorm.Session) (err error) {
if l.ID < 1 {
return ErrListDoesNotExist{ID: l.ID}
list = &List{}
if listID < 1 {
return nil, ErrListDoesNotExist{ID: listID}
}
// We need to re-init our list object, because otherwise xorm creates a "where for every item in that list object,
// leading to not finding anything if the id is good, but for example the title is different.
id := l.ID
*l = List{}
exists, err := s.Where("id = ?", id).Get(l)
exists, err := s.Where("id = ?", listID).Get(list)
if err != nil {
return
}
if !exists {
return ErrListDoesNotExist{ID: l.ID}
return nil, ErrListDoesNotExist{ID: listID}
}
return
}
// GetListSimplByTaskID gets a list by a task id
func GetListSimplByTaskID(taskID int64) (l *List, err error) {
func GetListSimplByTaskID(s *xorm.Session, taskID int64) (l *List, err error) {
// We need to re-init our list object, because otherwise xorm creates a "where for every item in that list object,
// leading to not finding anything if the id is good, but for example the title is different.
var list List
exists, err := x.
exists, err := s.
Select("list.*").
Table(List{}).
Join("INNER", "tasks", "list.id = tasks.list_id").
@ -292,9 +282,9 @@ func GetListSimplByTaskID(taskID int64) (l *List, err error) {
}
// GetListsByIDs returns a map of lists from a slice with list ids
func GetListsByIDs(listIDs []int64) (lists map[int64]*List, err error) {
func GetListsByIDs(s *xorm.Session, listIDs []int64) (lists map[int64]*List, err error) {
lists = make(map[int64]*List, len(listIDs))
err = x.In("id", listIDs).Find(&lists)
err = s.In("id", listIDs).Find(&lists)
return
}
@ -307,8 +297,8 @@ type listOptions struct {
}
// Gets the lists only, without any tasks or so
func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) {
fullUser, err := user.GetUserByID(opts.user.ID)
func getRawListsForUser(s *xorm.Session, opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) {
fullUser, err := user.GetUserByID(s, opts.user.ID)
if err != nil {
return nil, 0, 0, err
}
@ -344,7 +334,7 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota
// Gets all Lists where the user is either owner or in a team which has access to the list
// Or in a team which has namespace read access
query := x.Select("l.*").
query := s.Select("l.*").
Table("list").
Alias("l").
Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id").
@ -372,7 +362,7 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota
return nil, 0, 0, err
}
totalItems, err = x.
totalItems, err = s.
Table("list").
Alias("l").
Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id").
@ -396,8 +386,8 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota
return lists, len(lists), totalItems, err
}
// AddListDetails adds owner user objects and list tasks to all lists in the slice
func AddListDetails(lists []*List) (err error) {
// addListDetails adds owner user objects and list tasks to all lists in the slice
func addListDetails(s *xorm.Session, lists []*List) (err error) {
var ownerIDs []int64
for _, l := range lists {
ownerIDs = append(ownerIDs, l.OwnerID)
@ -405,7 +395,7 @@ func AddListDetails(lists []*List) (err error) {
// Get all list owners
owners := map[int64]*user.User{}
err = x.In("id", ownerIDs).Find(&owners)
err = s.In("id", ownerIDs).Find(&owners)
if err != nil {
return
}
@ -423,7 +413,7 @@ func AddListDetails(lists []*List) (err error) {
// Unsplash background file info
us := []*UnsplashPhoto{}
err = x.In("file_id", fileIDs).Find(&us)
err = s.In("file_id", fileIDs).Find(&us)
if err != nil {
return
}
@ -450,15 +440,15 @@ type NamespaceList struct {
}
// CheckIsArchived returns an ErrListIsArchived or ErrNamespaceIsArchived if the list or its namespace is archived.
func (l *List) CheckIsArchived() (err error) {
func (l *List) CheckIsArchived(s *xorm.Session) (err error) {
// When creating a new list, we check if the namespace is archived
if l.ID == 0 {
n := &Namespace{ID: l.NamespaceID}
return n.CheckIsArchived()
return n.CheckIsArchived(s)
}
nl := &NamespaceList{}
exists, err := x.
exists, err := s.
Table("list").
Join("LEFT", "namespaces", "list.namespace_id = namespaces.id").
Where("list.id = ? AND (list.is_archived = true OR namespaces.is_archived = true)", l.ID).
@ -476,11 +466,11 @@ func (l *List) CheckIsArchived() (err error) {
}
// CreateOrUpdateList updates a list or creates it if it doesn't exist
func CreateOrUpdateList(list *List) (err error) {
func CreateOrUpdateList(s *xorm.Session, list *List) (err error) {
// Check if the namespace exists
if list.NamespaceID != 0 && list.NamespaceID != FavoritesPseudoNamespace.ID {
_, err = GetNamespaceByID(list.NamespaceID)
_, err = GetNamespaceByID(s, list.NamespaceID)
if err != nil {
return err
}
@ -488,7 +478,7 @@ func CreateOrUpdateList(list *List) (err error) {
// Check if the identifier is unique and not empty
if list.Identifier != "" {
exists, err := x.
exists, err := s.
Where("identifier = ?", list.Identifier).
And("id != ?", list.ID).
Exist(&List{})
@ -501,7 +491,7 @@ func CreateOrUpdateList(list *List) (err error) {
}
if list.ID == 0 {
_, err = x.Insert(list)
_, err = s.Insert(list)
metrics.UpdateCount(1, metrics.ListCountKey)
} else {
// We need to specify the cols we want to update here to be able to un-archive lists
@ -516,7 +506,7 @@ func CreateOrUpdateList(list *List) (err error) {
colsToUpdate = append(colsToUpdate, "description")
}
_, err = x.
_, err = s.
ID(list.ID).
Cols(colsToUpdate...).
Update(list)
@ -526,12 +516,13 @@ func CreateOrUpdateList(list *List) (err error) {
return
}
err = list.GetSimpleByID()
l, err := GetListSimpleByID(s, list.ID)
if err != nil {
return
return err
}
err = list.ReadOne()
*list = *l
err = list.ReadOne(s)
return
}
@ -550,33 +541,23 @@ func CreateOrUpdateList(list *List) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [post]
func (l *List) Update() (err error) {
return CreateOrUpdateList(l)
func (l *List) Update(s *xorm.Session) (err error) {
return CreateOrUpdateList(s, l)
}
func updateListLastUpdated(list *List) (err error) {
s := x.NewSession()
err = updateListLastUpdatedS(s, list)
if err != nil {
_ = s.Rollback()
return err
}
return nil
}
func updateListLastUpdatedS(s *xorm.Session, list *List) error {
func updateListLastUpdated(s *xorm.Session, list *List) error {
_, err := s.ID(list.ID).Cols("updated").Update(list)
return err
}
func updateListByTaskID(taskID int64) (err error) {
func updateListByTaskID(s *xorm.Session, taskID int64) (err error) {
// need to get the task to update the list last updated timestamp
task, err := GetTaskByIDSimple(taskID)
task, err := GetTaskByIDSimple(s, taskID)
if err != nil {
return err
}
return updateListLastUpdated(&List{ID: task.ListID})
return updateListLastUpdated(s, &List{ID: task.ListID})
}
// Create implements the create method of CRUDable
@ -593,8 +574,8 @@ func updateListByTaskID(taskID int64) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/lists [put]
func (l *List) Create(a web.Auth) (err error) {
err = l.CheckIsArchived()
func (l *List) Create(s *xorm.Session, a web.Auth) (err error) {
err = l.CheckIsArchived(s)
if err != nil {
return err
}
@ -608,7 +589,7 @@ func (l *List) Create(a web.Auth) (err error) {
l.Owner = doer
l.ID = 0 // Otherwise only the first time a new list would be created
err = CreateOrUpdateList(l)
err = CreateOrUpdateList(s, l)
if err != nil {
return
}
@ -618,7 +599,7 @@ func (l *List) Create(a web.Auth) (err error) {
ListID: l.ID,
Title: "New Bucket",
}
return b.Create(a)
return b.Create(s, a)
}
// Delete implements the delete method of CRUDable
@ -633,27 +614,27 @@ func (l *List) Create(a web.Auth) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [delete]
func (l *List) Delete() (err error) {
func (l *List) Delete(s *xorm.Session) (err error) {
// Delete the list
_, err = x.ID(l.ID).Delete(&List{})
_, err = s.ID(l.ID).Delete(&List{})
if err != nil {
return
}
metrics.UpdateCount(-1, metrics.ListCountKey)
// Delete all todotasks on that list
_, err = x.Where("list_id = ?", l.ID).Delete(&Task{})
_, err = s.Where("list_id = ?", l.ID).Delete(&Task{})
return
}
// SetListBackground sets a background file as list background in the db
func SetListBackground(listID int64, background *files.File) (err error) {
func SetListBackground(s *xorm.Session, listID int64, background *files.File) (err error) {
l := &List{
ID: listID,
BackgroundFileID: background.ID,
}
_, err = x.
_, err = s.
Where("id = ?", l.ID).
Cols("background_file_id").
Update(l)

View File

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/utils"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// ListDuplicate holds everything needed to duplicate a list
@ -38,17 +39,17 @@ type ListDuplicate struct {
}
// CanCreate checks if a user has the right to duplicate a list
func (ld *ListDuplicate) CanCreate(a web.Auth) (canCreate bool, err error) {
func (ld *ListDuplicate) CanCreate(s *xorm.Session, a web.Auth) (canCreate bool, err error) {
// List Exists + user has read access to list
ld.List = &List{ID: ld.ListID}
canRead, _, err := ld.List.CanRead(a)
canRead, _, err := ld.List.CanRead(s, a)
if err != nil || !canRead {
return canRead, err
}
// Namespace exists + user has write access to is (-> can create new lists)
ld.List.NamespaceID = ld.NamespaceID
return ld.List.CanCreate(a)
return ld.List.CanCreate(s, a)
}
// Create duplicates a list
@ -66,7 +67,7 @@ func (ld *ListDuplicate) CanCreate(a web.Auth) (canCreate bool, err error) {
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/duplicate [put]
//nolint:gocyclo
func (ld *ListDuplicate) Create(a web.Auth) (err error) {
func (ld *ListDuplicate) Create(s *xorm.Session, a web.Auth) (err error) {
log.Debugf("Duplicating list %d", ld.ListID)
@ -74,7 +75,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
ld.List.Identifier = "" // Reset the identifier to trigger regenerating a new one
// Set the owner to the current user
ld.List.OwnerID = a.GetID()
if err := CreateOrUpdateList(ld.List); err != nil {
if err := CreateOrUpdateList(s, ld.List); err != nil {
// If there is no available unique list identifier, just reset it.
if IsErrListIdentifierIsNotUnique(err) {
ld.List.Identifier = ""
@ -90,7 +91,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Used to map the newly created tasks to their new buckets
bucketMap := make(map[int64]int64)
buckets := []*Bucket{}
err = x.Where("list_id = ?", ld.ListID).Find(&buckets)
err = s.Where("list_id = ?", ld.ListID).Find(&buckets)
if err != nil {
return
}
@ -98,7 +99,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
oldID := b.ID
b.ID = 0
b.ListID = ld.List.ID
if err := b.Create(a); err != nil {
if err := b.Create(s, a); err != nil {
return err
}
bucketMap[oldID] = b.ID
@ -107,7 +108,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
log.Debugf("Duplicated all buckets from list %d into %d", ld.ListID, ld.List.ID)
// Get all tasks + all task details
tasks, _, _, err := getTasksForLists([]*List{{ID: ld.ListID}}, a, &taskOptions{})
tasks, _, _, err := getTasksForLists(s, []*List{{ID: ld.ListID}}, a, &taskOptions{})
if err != nil {
return err
}
@ -123,10 +124,8 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
t.ListID = ld.List.ID
t.BucketID = bucketMap[t.BucketID]
t.UID = ""
s := x.NewSession()
err := createTask(s, t, a, false)
if err != nil {
_ = s.Rollback()
return err
}
taskMap[oldID] = t.ID
@ -138,7 +137,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Save all attachments
// We also duplicate all underlying files since they could be modified in one list which would result in
// file changes in the other list which is not something we want.
attachments, err := getTaskAttachmentsByTaskIDs(oldTaskIDs)
attachments, err := getTaskAttachmentsByTaskIDs(s, oldTaskIDs)
if err != nil {
return err
}
@ -164,7 +163,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
return err
}
err := attachment.NewAttachment(attachment.File.File, attachment.File.Name, attachment.File.Size, a)
err := attachment.NewAttachment(s, attachment.File.File, attachment.File.Name, attachment.File.Size, a)
if err != nil {
return err
}
@ -180,7 +179,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Copy label tasks (not the labels)
labelTasks := []*LabelTask{}
err = x.In("task_id", oldTaskIDs).Find(&labelTasks)
err = s.In("task_id", oldTaskIDs).Find(&labelTasks)
if err != nil {
return
}
@ -188,7 +187,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
for _, lt := range labelTasks {
lt.ID = 0
lt.TaskID = taskMap[lt.TaskID]
if _, err := x.Insert(lt); err != nil {
if _, err := s.Insert(lt); err != nil {
return err
}
}
@ -198,7 +197,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Assignees
// Only copy those assignees who have access to the task
assignees := []*TaskAssginee{}
err = x.In("task_id", oldTaskIDs).Find(&assignees)
err = s.In("task_id", oldTaskIDs).Find(&assignees)
if err != nil {
return
}
@ -207,7 +206,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
ID: taskMap[a.TaskID],
ListID: ld.List.ID,
}
if err := t.addNewAssigneeByID(a.UserID, ld.List); err != nil {
if err := t.addNewAssigneeByID(s, a.UserID, ld.List); err != nil {
if IsErrUserDoesNotHaveAccessToList(err) {
continue
}
@ -219,14 +218,14 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Comments
comments := []*TaskComment{}
err = x.In("task_id", oldTaskIDs).Find(&comments)
err = s.In("task_id", oldTaskIDs).Find(&comments)
if err != nil {
return
}
for _, c := range comments {
c.ID = 0
c.TaskID = taskMap[c.TaskID]
if _, err := x.Insert(c); err != nil {
if _, err := s.Insert(c); err != nil {
return err
}
}
@ -237,7 +236,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Low-Effort: Only copy those relations which are between tasks in the same list
// because we can do that without a lot of hassle
relations := []*TaskRelation{}
err = x.In("task_id", oldTaskIDs).Find(&relations)
err = s.In("task_id", oldTaskIDs).Find(&relations)
if err != nil {
return
}
@ -249,7 +248,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
r.ID = 0
r.OtherTaskID = otherTaskID
r.TaskID = taskMap[r.TaskID]
if _, err := x.Insert(r); err != nil {
if _, err := s.Insert(r); err != nil {
return err
}
}
@ -276,19 +275,19 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
}
// Get unsplash info if applicable
up, err := GetUnsplashPhotoByFileID(ld.List.BackgroundFileID)
up, err := GetUnsplashPhotoByFileID(s, ld.List.BackgroundFileID)
if err != nil && files.IsErrFileIsNotUnsplashFile(err) {
return err
}
if up != nil {
up.ID = 0
up.FileID = file.ID
if err := up.Save(); err != nil {
if err := up.Save(s); err != nil {
return err
}
}
if err := SetListBackground(ld.List.ID, file); err != nil {
if err := SetListBackground(s, ld.List.ID, file); err != nil {
return err
}
@ -298,14 +297,14 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Rights / Shares
// To keep it simple(r) we will only copy rights which are directly used with the list, no namespace changes.
users := []*ListUser{}
err = x.Where("list_id = ?", ld.ListID).Find(&users)
err = s.Where("list_id = ?", ld.ListID).Find(&users)
if err != nil {
return
}
for _, u := range users {
u.ID = 0
u.ListID = ld.List.ID
if _, err := x.Insert(u); err != nil {
if _, err := s.Insert(u); err != nil {
return err
}
}
@ -313,21 +312,21 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
log.Debugf("Duplicated user shares from list %d into %d", ld.ListID, ld.List.ID)
teams := []*TeamList{}
err = x.Where("list_id = ?", ld.ListID).Find(&teams)
err = s.Where("list_id = ?", ld.ListID).Find(&teams)
if err != nil {
return
}
for _, t := range teams {
t.ID = 0
t.ListID = ld.List.ID
if _, err := x.Insert(t); err != nil {
if _, err := s.Insert(t); err != nil {
return err
}
}
// Generate new link shares if any are available
linkShares := []*LinkSharing{}
err = x.Where("list_id = ?", ld.ListID).Find(&linkShares)
err = s.Where("list_id = ?", ld.ListID).Find(&linkShares)
if err != nil {
return
}
@ -335,7 +334,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
share.ID = 0
share.ListID = ld.List.ID
share.Hash = utils.MakeRandomString(40)
if _, err := x.Insert(share); err != nil {
if _, err := s.Insert(share); err != nil {
return err
}
}

View File

@ -29,6 +29,8 @@ func TestListDuplicate(t *testing.T) {
db.LoadAndAssertFixtures(t)
files.InitTestFileFixtures(t)
s := db.NewSession()
defer s.Close()
u := &user.User{
ID: 1,
@ -38,10 +40,10 @@ func TestListDuplicate(t *testing.T) {
ListID: 1,
NamespaceID: 1,
}
can, err := l.CanCreate(u)
can, err := l.CanCreate(s, u)
assert.NoError(t, err)
assert.True(t, can)
err = l.Create(u)
err = l.Create(s, u)
assert.NoError(t, err)
// To make this test 100% useful, it would need to assert a lot more stuff, but it is good enough for now.
// Also, we're lacking utility functions to do all needed assertions.

View File

@ -20,10 +20,11 @@ import (
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/builder"
"xorm.io/xorm"
)
// CanWrite return whether the user can write on that list or not
func (l *List) CanWrite(a web.Auth) (bool, error) {
func (l *List) CanWrite(s *xorm.Session, a web.Auth) (bool, error) {
// The favorite list can't be edited
if l.ID == FavoritesPseudoList.ID {
@ -31,15 +32,14 @@ func (l *List) CanWrite(a web.Auth) (bool, error) {
}
// Get the list and check the right
originalList := &List{ID: l.ID}
err := originalList.GetSimpleByID()
originalList, err := GetListSimpleByID(s, l.ID)
if err != nil {
return false, err
}
// We put the result of the is archived check in a separate variable to be able to return it later without
// needing to recheck it again
errIsArchived := originalList.CheckIsArchived()
errIsArchived := originalList.CheckIsArchived(s)
var canWrite bool
@ -59,7 +59,7 @@ func (l *List) CanWrite(a web.Auth) (bool, error) {
return canWrite, errIsArchived
}
canWrite, _, err = originalList.checkRight(a, RightWrite, RightAdmin)
canWrite, _, err = originalList.checkRight(s, a, RightWrite, RightAdmin)
if err != nil {
return false, err
}
@ -67,7 +67,7 @@ func (l *List) CanWrite(a web.Auth) (bool, error) {
}
// CanRead checks if a user has read access to a list
func (l *List) CanRead(a web.Auth) (bool, int, error) {
func (l *List) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
// The favorite list needs a special treatment
if l.ID == FavoritesPseudoList.ID {
@ -84,14 +84,18 @@ func (l *List) CanRead(a web.Auth) (bool, int, error) {
// Saved Filter Lists need a special case
if getSavedFilterIDFromListID(l.ID) > 0 {
sf := &SavedFilter{ID: getSavedFilterIDFromListID(l.ID)}
return sf.CanRead(a)
return sf.CanRead(s, a)
}
// Check if the user is either owner or can read
if err := l.GetSimpleByID(); err != nil {
var err error
originalList, err := GetListSimpleByID(s, l.ID)
if err != nil {
return false, 0, err
}
*l = *originalList
// Check if we're dealing with a share auth
shareAuth, ok := a.(*LinkSharing)
if ok {
@ -102,16 +106,16 @@ func (l *List) CanRead(a web.Auth) (bool, int, error) {
if l.isOwner(&user.User{ID: a.GetID()}) {
return true, int(RightAdmin), nil
}
return l.checkRight(a, RightRead, RightWrite, RightAdmin)
return l.checkRight(s, a, RightRead, RightWrite, RightAdmin)
}
// CanUpdate checks if the user can update a list
func (l *List) CanUpdate(a web.Auth) (canUpdate bool, err error) {
func (l *List) CanUpdate(s *xorm.Session, a web.Auth) (canUpdate bool, err error) {
// The favorite list can't be edited
if l.ID == FavoritesPseudoList.ID {
return false, nil
}
canUpdate, err = l.CanWrite(a)
canUpdate, err = l.CanWrite(s, a)
// If the list is archived and the user tries to un-archive it, let the request through
if IsErrListIsArchived(err) && !l.IsArchived {
err = nil
@ -120,26 +124,25 @@ func (l *List) CanUpdate(a web.Auth) (canUpdate bool, err error) {
}
// CanDelete checks if the user can delete a list
func (l *List) CanDelete(a web.Auth) (bool, error) {
return l.IsAdmin(a)
func (l *List) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return l.IsAdmin(s, a)
}
// CanCreate checks if the user can create a list
func (l *List) CanCreate(a web.Auth) (bool, error) {
func (l *List) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
// A user can create a list if they have write access to the namespace
n := &Namespace{ID: l.NamespaceID}
return n.CanWrite(a)
return n.CanWrite(s, a)
}
// IsAdmin returns whether the user has admin rights on the list or not
func (l *List) IsAdmin(a web.Auth) (bool, error) {
func (l *List) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
// The favorite list can't be edited
if l.ID == FavoritesPseudoList.ID {
return false, nil
}
originalList := &List{ID: l.ID}
err := originalList.GetSimpleByID()
originalList, err := GetListSimpleByID(s, l.ID)
if err != nil {
return false, err
}
@ -156,7 +159,7 @@ func (l *List) IsAdmin(a web.Auth) (bool, error) {
if originalList.isOwner(&user.User{ID: a.GetID()}) {
return true, nil
}
is, _, err := originalList.checkRight(a, RightAdmin)
is, _, err := originalList.checkRight(s, a, RightAdmin)
return is, err
}
@ -166,7 +169,7 @@ func (l *List) isOwner(u *user.User) bool {
}
// Checks n different rights for any given user
func (l *List) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
func (l *List) checkRight(s *xorm.Session, a web.Auth, rights ...Right) (bool, int, error) {
/*
The following loop creates an sql condition like this one:
@ -218,7 +221,7 @@ func (l *List) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
r := &allListRights{}
var maxRight = 0
exists, err := x.
exists, err := s.
Table("list").
Alias("l").
// User stuff

View File

@ -20,6 +20,7 @@ import (
"time"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// TeamList defines the relation between a team and a list
@ -68,7 +69,7 @@ type TeamWithRight struct {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/teams [put]
func (tl *TeamList) Create(a web.Auth) (err error) {
func (tl *TeamList) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the rights are valid
if err = tl.Right.isValid(); err != nil {
@ -76,19 +77,19 @@ func (tl *TeamList) Create(a web.Auth) (err error) {
}
// Check if the team exists
_, err = GetTeamByID(tl.TeamID)
_, err = GetTeamByID(s, tl.TeamID)
if err != nil {
return
}
// Check if the list exists
l := &List{ID: tl.ListID}
if err := l.GetSimpleByID(); err != nil {
l, err := GetListSimpleByID(s, tl.ListID)
if err != nil {
return err
}
// Check if the team is already on the list
exists, err := x.Where("team_id = ?", tl.TeamID).
exists, err := s.Where("team_id = ?", tl.TeamID).
And("list_id = ?", tl.ListID).
Get(&TeamList{})
if err != nil {
@ -99,12 +100,12 @@ func (tl *TeamList) Create(a web.Auth) (err error) {
}
// Insert the new team
_, err = x.Insert(tl)
_, err = s.Insert(tl)
if err != nil {
return err
}
err = updateListLastUpdated(l)
err = updateListLastUpdated(s, l)
return
}
@ -121,16 +122,17 @@ func (tl *TeamList) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "Team or list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/teams/{teamID} [delete]
func (tl *TeamList) Delete() (err error) {
func (tl *TeamList) Delete(s *xorm.Session) (err error) {
// Check if the team exists
_, err = GetTeamByID(tl.TeamID)
_, err = GetTeamByID(s, tl.TeamID)
if err != nil {
return
}
// Check if the team has access to the list
has, err := x.Where("team_id = ? AND list_id = ?", tl.TeamID, tl.ListID).
has, err := s.
Where("team_id = ? AND list_id = ?", tl.TeamID, tl.ListID).
Get(&TeamList{})
if err != nil {
return
@ -140,14 +142,14 @@ func (tl *TeamList) Delete() (err error) {
}
// Delete the relation
_, err = x.Where("team_id = ?", tl.TeamID).
_, err = s.Where("team_id = ?", tl.TeamID).
And("list_id = ?", tl.ListID).
Delete(TeamList{})
if err != nil {
return err
}
err = updateListLastUpdated(&List{ID: tl.ListID})
err = updateListLastUpdated(s, &List{ID: tl.ListID})
return
}
@ -166,10 +168,10 @@ func (tl *TeamList) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the list."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/teams [get]
func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
func (tl *TeamList) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
// Check if the user can read the namespace
l := &List{ID: tl.ListID}
canRead, _, err := l.CanRead(a)
canRead, _, err := l.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -181,7 +183,7 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get the teams
all := []*TeamWithRight{}
query := x.
query := s.
Table("teams").
Join("INNER", "team_list", "team_id = teams.id").
Where("team_list.list_id = ?", tl.ListID).
@ -199,12 +201,12 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r
teams = append(teams, &t.Team)
}
err = addMoreInfoToTeams(teams)
err = addMoreInfoToTeams(s, teams)
if err != nil {
return
}
totalItems, err = x.
totalItems, err = s.
Table("teams").
Join("INNER", "team_list", "team_id = teams.id").
Where("team_list.list_id = ?", tl.ListID).
@ -232,14 +234,14 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r
// @Failure 404 {object} web.HTTPError "Team or list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/teams/{teamID} [post]
func (tl *TeamList) Update() (err error) {
func (tl *TeamList) Update(s *xorm.Session) (err error) {
// Check if the right is valid
if err := tl.Right.isValid(); err != nil {
return err
}
_, err = x.
_, err = s.
Where("list_id = ? AND team_id = ?", tl.ListID, tl.TeamID).
Cols("right").
Update(tl)
@ -247,6 +249,6 @@ func (tl *TeamList) Update() (err error) {
return err
}
err = updateListLastUpdated(&List{ID: tl.ListID})
err = updateListLastUpdated(s, &List{ID: tl.ListID})
return
}

View File

@ -18,29 +18,30 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if the user can create a team <-> list relation
func (tl *TeamList) CanCreate(a web.Auth) (bool, error) {
return tl.canDoTeamList(a)
func (tl *TeamList) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return tl.canDoTeamList(s, a)
}
// CanDelete checks if the user can delete a team <-> list relation
func (tl *TeamList) CanDelete(a web.Auth) (bool, error) {
return tl.canDoTeamList(a)
func (tl *TeamList) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return tl.canDoTeamList(s, a)
}
// CanUpdate checks if the user can update a team <-> list relation
func (tl *TeamList) CanUpdate(a web.Auth) (bool, error) {
return tl.canDoTeamList(a)
func (tl *TeamList) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return tl.canDoTeamList(s, a)
}
func (tl *TeamList) canDoTeamList(a web.Auth) (bool, error) {
func (tl *TeamList) canDoTeamList(s *xorm.Session, a web.Auth) (bool, error) {
// Link shares aren't allowed to do anything
if _, is := a.(*LinkSharing); is {
return false, nil
}
l := List{ID: tl.ListID}
return l.IsAdmin(a)
return l.IsAdmin(s, a)
}

View File

@ -37,20 +37,24 @@ func TestTeamList_ReadAll(t *testing.T) {
ListID: 3,
}
db.LoadAndAssertFixtures(t)
teams, _, _, err := tl.ReadAll(u, "", 1, 50)
s := db.NewSession()
teams, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice)
s := reflect.ValueOf(teams)
assert.Equal(t, s.Len(), 1)
ts := reflect.ValueOf(teams)
assert.Equal(t, ts.Len(), 1)
_ = s.Close()
})
t.Run("nonexistant list", func(t *testing.T) {
tl := TeamList{
ListID: 99999,
}
db.LoadAndAssertFixtures(t)
_, _, _, err := tl.ReadAll(u, "", 1, 50)
s := db.NewSession()
_, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err))
_ = s.Close()
})
t.Run("namespace owner", func(t *testing.T) {
tl := TeamList{
@ -59,8 +63,10 @@ func TestTeamList_ReadAll(t *testing.T) {
Right: RightAdmin,
}
db.LoadAndAssertFixtures(t)
_, _, _, err := tl.ReadAll(u, "", 1, 50)
s := db.NewSession()
_, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err)
_ = s.Close()
})
t.Run("no access", func(t *testing.T) {
tl := TeamList{
@ -69,9 +75,11 @@ func TestTeamList_ReadAll(t *testing.T) {
Right: RightAdmin,
}
db.LoadAndAssertFixtures(t)
_, _, _, err := tl.ReadAll(u, "", 1, 50)
s := db.NewSession()
_, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.Error(t, err)
assert.True(t, IsErrNeedToHaveListReadAccess(err))
_ = s.Close()
})
}
@ -79,14 +87,17 @@ func TestTeamList_Create(t *testing.T) {
u := &user.User{ID: 1}
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 1,
ListID: 1,
Right: RightAdmin,
}
allowed, _ := tl.CanCreate(u)
allowed, _ := tl.CanCreate(s, u)
assert.True(t, allowed)
err := tl.Create(u)
err := tl.Create(s, u)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_list", map[string]interface{}{
"team_id": 1,
@ -96,56 +107,67 @@ func TestTeamList_Create(t *testing.T) {
})
t.Run("team already has access", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 1,
ListID: 3,
Right: RightAdmin,
}
err := tl.Create(u)
err := tl.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrTeamAlreadyHasAccess(err))
_ = s.Close()
})
t.Run("wrong rights", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 1,
ListID: 1,
Right: RightUnknown,
}
err := tl.Create(u)
err := tl.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrInvalidRight(err))
_ = s.Close()
})
t.Run("nonexistant team", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 9999,
ListID: 1,
}
err := tl.Create(u)
err := tl.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
})
t.Run("nonexistant list", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 1,
ListID: 9999,
}
err := tl.Create(u)
err := tl.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err))
_ = s.Close()
})
}
func TestTeamList_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 1,
ListID: 3,
}
err := tl.Delete()
err := tl.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "team_list", map[string]interface{}{
"team_id": 1,
@ -154,23 +176,27 @@ func TestTeamList_Delete(t *testing.T) {
})
t.Run("nonexistant team", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 9999,
ListID: 1,
}
err := tl.Delete()
err := tl.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
})
t.Run("nonexistant list", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{
TeamID: 1,
ListID: 9999,
}
err := tl.Delete()
err := tl.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotHaveAccessToList(err))
_ = s.Close()
})
}
@ -229,6 +255,7 @@ func TestTeamList_Update(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := &TeamList{
ID: tt.fields.ID,
@ -240,13 +267,15 @@ func TestTeamList_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := tl.Update()
err := tl.Update(s)
if (err != nil) != tt.wantErr {
t.Errorf("TeamList.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("TeamList.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "team_list", map[string]interface{}{
"list_id": tt.fields.ListID,

View File

@ -35,12 +35,15 @@ func TestList_CreateOrUpdate(t *testing.T) {
t.Run("create", func(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
Title: "test",
Description: "Lorem Ipsum",
NamespaceID: 1,
}
err := list.Create(usr)
err := list.Create(s, usr)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "list", map[string]interface{}{
"id": list.ID,
@ -51,49 +54,56 @@ func TestList_CreateOrUpdate(t *testing.T) {
})
t.Run("nonexistant namespace", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
Title: "test",
Description: "Lorem Ipsum",
NamespaceID: 999999,
}
err := list.Create(usr)
err := list.Create(s, usr)
assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
})
t.Run("nonexistant owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
usr := &user.User{ID: 9482385}
list := List{
Title: "test",
Description: "Lorem Ipsum",
NamespaceID: 1,
}
err := list.Create(usr)
err := list.Create(s, usr)
assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
})
t.Run("existing identifier", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
Title: "test",
Description: "Lorem Ipsum",
Identifier: "test1",
NamespaceID: 1,
}
err := list.Create(usr)
err := list.Create(s, usr)
assert.Error(t, err)
assert.True(t, IsErrListIdentifierIsNotUnique(err))
_ = s.Close()
})
t.Run("non ascii characters", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
Title: "приффки фсем",
Description: "Lorem Ipsum",
NamespaceID: 1,
}
err := list.Create(usr)
err := list.Create(s, usr)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "list", map[string]interface{}{
"id": list.ID,
@ -107,6 +117,7 @@ func TestList_CreateOrUpdate(t *testing.T) {
t.Run("update", func(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
ID: 1,
Title: "test",
@ -114,7 +125,9 @@ func TestList_CreateOrUpdate(t *testing.T) {
NamespaceID: 1,
}
list.Description = "Lorem Ipsum dolor sit amet."
err := list.Update()
err := list.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "list", map[string]interface{}{
"id": list.ID,
@ -125,37 +138,43 @@ func TestList_CreateOrUpdate(t *testing.T) {
})
t.Run("nonexistant", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
ID: 99999999,
Title: "test",
}
err := list.Update()
err := list.Update(s)
assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err))
_ = s.Close()
})
t.Run("existing identifier", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
Title: "test",
Description: "Lorem Ipsum",
Identifier: "test1",
NamespaceID: 1,
}
err := list.Create(usr)
err := list.Create(s, usr)
assert.Error(t, err)
assert.True(t, IsErrListIdentifierIsNotUnique(err))
_ = s.Close()
})
})
}
func TestList_Delete(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{
ID: 1,
}
err := list.Delete()
err := list.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "list", map[string]interface{}{
"id": 1,
@ -165,30 +184,34 @@ func TestList_Delete(t *testing.T) {
func TestList_ReadAll(t *testing.T) {
t.Run("all in namespace", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
// Get all lists for our namespace
lists, err := GetListsByNamespaceID(1, &user.User{})
lists, err := GetListsByNamespaceID(s, 1, &user.User{})
assert.NoError(t, err)
assert.Equal(t, len(lists), 2)
_ = s.Close()
})
t.Run("all lists for user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
u := &user.User{ID: 1}
list := List{}
lists3, _, _, err := list.ReadAll(u, "", 1, 50)
lists3, _, _, err := list.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(lists3).Kind(), reflect.Slice)
s := reflect.ValueOf(lists3)
assert.Equal(t, 16, s.Len())
ls := reflect.ValueOf(lists3)
assert.Equal(t, 16, ls.Len())
_ = s.Close()
})
t.Run("lists for nonexistant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
usr := &user.User{ID: 999999}
list := List{}
_, _, _, err := list.ReadAll(usr, "", 1, 50)
_, _, _, err := list.ReadAll(s, usr, "", 1, 50)
assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
})
}

View File

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// ListUser represents a list <-> user relation
@ -71,7 +72,7 @@ type UserWithRight struct {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/users [put]
func (lu *ListUser) Create(a web.Auth) (err error) {
func (lu *ListUser) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the right is valid
if err := lu.Right.isValid(); err != nil {
@ -79,17 +80,17 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
}
// Check if the list exists
l := &List{ID: lu.ListID}
if err = l.GetSimpleByID(); err != nil {
l, err := GetListSimpleByID(s, lu.ListID)
if err != nil {
return
}
// Check if the user exists
user, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(s, lu.Username)
if err != nil {
return err
}
lu.UserID = user.ID
lu.UserID = u.ID
// Check if the user already has access or is owner of that list
// We explicitly DONT check for teams here
@ -97,7 +98,7 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
return ErrUserAlreadyHasAccess{UserID: lu.UserID, ListID: lu.ListID}
}
exist, err := x.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{})
exist, err := s.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{})
if err != nil {
return
}
@ -106,12 +107,12 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
}
// Insert user <-> list relation
_, err = x.Insert(lu)
_, err = s.Insert(lu)
if err != nil {
return err
}
err = updateListLastUpdated(l)
err = updateListLastUpdated(s, l)
return
}
@ -128,17 +129,18 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "user or list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/users/{userID} [delete]
func (lu *ListUser) Delete() (err error) {
func (lu *ListUser) Delete(s *xorm.Session) (err error) {
// Check if the user exists
user, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(s, lu.Username)
if err != nil {
return
}
lu.UserID = user.ID
lu.UserID = u.ID
// Check if the user has access to the list
has, err := x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
has, err := s.
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
Get(&ListUser{})
if err != nil {
return
@ -147,13 +149,14 @@ func (lu *ListUser) Delete() (err error) {
return ErrUserDoesNotHaveAccessToList{ListID: lu.ListID, UserID: lu.UserID}
}
_, err = x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
_, err = s.
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
Delete(&ListUser{})
if err != nil {
return err
}
err = updateListLastUpdated(&List{ID: lu.ListID})
err = updateListLastUpdated(s, &List{ID: lu.ListID})
return
}
@ -172,10 +175,10 @@ func (lu *ListUser) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the list."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/users [get]
func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (lu *ListUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has access to the list
l := &List{ID: lu.ListID}
canRead, _, err := l.CanRead(a)
canRead, _, err := l.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -187,7 +190,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all users
all := []*UserWithRight{}
query := x.
query := s.
Join("INNER", "users_list", "user_id = users.id").
Where("users_list.list_id = ?", lu.ListID).
Where("users.username LIKE ?", "%"+search+"%")
@ -204,7 +207,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
u.Email = ""
}
numberOfTotalItems, err = x.
numberOfTotalItems, err = s.
Join("INNER", "users_list", "user_id = users.id").
Where("users_list.list_id = ?", lu.ListID).
Where("users.username LIKE ?", "%"+search+"%").
@ -228,7 +231,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
// @Failure 404 {object} web.HTTPError "User or list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/users/{userID} [post]
func (lu *ListUser) Update() (err error) {
func (lu *ListUser) Update(s *xorm.Session) (err error) {
// Check if the right is valid
if err := lu.Right.isValid(); err != nil {
@ -236,13 +239,13 @@ func (lu *ListUser) Update() (err error) {
}
// Check if the user exists
u, err := user.GetUserByUsername(lu.Username)
u, err := user.GetUserByUsername(s, lu.Username)
if err != nil {
return err
}
lu.UserID = u.ID
_, err = x.
_, err = s.
Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).
Cols("right").
Update(lu)
@ -250,6 +253,6 @@ func (lu *ListUser) Update() (err error) {
return err
}
err = updateListLastUpdated(&List{ID: lu.ListID})
err = updateListLastUpdated(s, &List{ID: lu.ListID})
return
}

View File

@ -18,24 +18,25 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if the user can create a new user <-> list relation
func (lu *ListUser) CanCreate(a web.Auth) (bool, error) {
return lu.canDoListUser(a)
func (lu *ListUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(s, a)
}
// CanDelete checks if the user can delete a user <-> list relation
func (lu *ListUser) CanDelete(a web.Auth) (bool, error) {
return lu.canDoListUser(a)
func (lu *ListUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(s, a)
}
// CanUpdate checks if the user can update a user <-> list relation
func (lu *ListUser) CanUpdate(a web.Auth) (bool, error) {
return lu.canDoListUser(a)
func (lu *ListUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(s, a)
}
func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) {
func (lu *ListUser) canDoListUser(s *xorm.Session, a web.Auth) (bool, error) {
// Link shares aren't allowed to do anything
if _, is := a.(*LinkSharing); is {
return false, nil
@ -43,5 +44,5 @@ func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) {
// Get the list and check if the user has write access on it
l := List{ID: lu.ListID}
return l.IsAdmin(a)
return l.IsAdmin(s, a)
}

View File

@ -80,6 +80,7 @@ func TestListUser_CanDoSomething(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
lu := &ListUser{
ID: tt.fields.ID,
@ -91,15 +92,16 @@ func TestListUser_CanDoSomething(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
if got, _ := lu.CanCreate(tt.args.a); got != tt.want["CanCreate"] {
if got, _ := lu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
t.Errorf("ListUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
}
if got, _ := lu.CanDelete(tt.args.a); got != tt.want["CanDelete"] {
if got, _ := lu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("ListUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
}
if got, _ := lu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] {
if got, _ := lu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("ListUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
}
_ = s.Close()
})
}
}

View File

@ -24,9 +24,9 @@ import (
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/user"
"gopkg.in/d4l3k/messagediff.v1"
"code.vikunja.io/web"
"github.com/stretchr/testify/assert"
"gopkg.in/d4l3k/messagediff.v1"
)
func TestListUser_Create(t *testing.T) {
@ -108,6 +108,7 @@ func TestListUser_Create(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
ul := &ListUser{
ID: tt.fields.ID,
@ -120,13 +121,17 @@ func TestListUser_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := ul.Create(tt.args.a)
err := ul.Create(s, tt.args.a)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Create() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "users_list", map[string]interface{}{
"user_id": ul.UserID,
@ -212,6 +217,7 @@ func TestListUser_ReadAll(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
ul := &ListUser{
ID: tt.fields.ID,
@ -223,7 +229,7 @@ func TestListUser_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
got, _, _, err := ul.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50)
got, _, _, err := ul.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
}
@ -233,6 +239,7 @@ func TestListUser_ReadAll(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(got, tt.want); !equal {
t.Errorf("ListUser.ReadAll() = %v, want %v, diff: %v", got, tt.want, diff)
}
_ = s.Close()
})
}
}
@ -292,6 +299,7 @@ func TestListUser_Update(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
lu := &ListUser{
ID: tt.fields.ID,
@ -303,13 +311,17 @@ func TestListUser_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := lu.Update()
err := lu.Update(s)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "users_list", map[string]interface{}{
"list_id": tt.fields.ListID,
@ -369,6 +381,7 @@ func TestListUser_Delete(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
lu := &ListUser{
ID: tt.fields.ID,
@ -380,13 +393,17 @@ func TestListUser_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := lu.Delete()
err := lu.Delete(s)
if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Delete() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertMissing(t, "users_list", map[string]interface{}{
"user_id": tt.fields.UserID,

View File

@ -23,12 +23,11 @@ import (
"time"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/metrics"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"github.com/imdario/mergo"
"xorm.io/builder"
"xorm.io/xorm"
)
// Namespace holds informations about a namespace
@ -95,55 +94,48 @@ func (Namespace) TableName() string {
}
// GetSimpleByID gets a namespace without things like the owner, it more or less only checks if it exists.
func (n *Namespace) GetSimpleByID() (err error) {
if n.ID == 0 {
return ErrNamespaceDoesNotExist{ID: n.ID}
func getNamespaceSimpleByID(s *xorm.Session, id int64) (namespace *Namespace, err error) {
if id == 0 {
return nil, ErrNamespaceDoesNotExist{ID: id}
}
// Get the namesapce with shared lists
if n.ID == -1 {
*n = SharedListsPseudoNamespace
return
if id == -1 {
return &SharedListsPseudoNamespace, nil
}
if n.ID == FavoritesPseudoNamespace.ID {
*n = FavoritesPseudoNamespace
return
if id == FavoritesPseudoNamespace.ID {
return &FavoritesPseudoNamespace, nil
}
namespaceFromDB := &Namespace{}
exists, err := x.Where("id = ?", n.ID).Get(namespaceFromDB)
namespace = &Namespace{}
exists, err := s.Where("id = ?", id).Get(namespace)
if err != nil {
return
}
if !exists {
return ErrNamespaceDoesNotExist{ID: n.ID}
return nil, ErrNamespaceDoesNotExist{ID: id}
}
// We don't want to override the provided user struct because this would break updating, so we have to merge it
if err := mergo.Merge(namespaceFromDB, n, mergo.WithOverride); err != nil {
return err
}
*n = *namespaceFromDB
return
}
// GetNamespaceByID returns a namespace object by its ID
func GetNamespaceByID(id int64) (namespace Namespace, err error) {
namespace = Namespace{ID: id}
err = namespace.GetSimpleByID()
func GetNamespaceByID(s *xorm.Session, id int64) (namespace *Namespace, err error) {
namespace, err = getNamespaceSimpleByID(s, id)
if err != nil {
return
}
// Get the namespace Owner
namespace.Owner, err = user.GetUserByID(namespace.OwnerID)
namespace.Owner, err = user.GetUserByID(s, namespace.OwnerID)
return
}
// CheckIsArchived returns an ErrNamespaceIsArchived if the namepace is archived.
func (n *Namespace) CheckIsArchived() error {
exists, err := x.
func (n *Namespace) CheckIsArchived(s *xorm.Session) error {
exists, err := s.
Where("id = ? AND is_archived = true", n.ID).
Exist(&Namespace{})
if err != nil {
@ -167,8 +159,12 @@ func (n *Namespace) CheckIsArchived() error {
// @Failure 403 {object} web.HTTPError "The user does not have access to that namespace."
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id} [get]
func (n *Namespace) ReadOne() (err error) {
*n, err = GetNamespaceByID(n.ID)
func (n *Namespace) ReadOne(s *xorm.Session) (err error) {
nn, err := GetNamespaceByID(s, n.ID)
if err != nil {
return err
}
*n = *nn
return
}
@ -207,7 +203,7 @@ func makeNamespaceSliceFromMap(namespaces map[int64]*NamespaceWithLists, userMap
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces [get]
//nolint:gocyclo
func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (n *Namespace) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
if _, is := a.(*LinkSharing); is {
return nil, 0, 0, ErrGenericForbidden{}
}
@ -249,7 +245,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
}
limit, start := getLimitFromPageIndex(page, perPage)
query := x.Select("namespaces.*").
query := s.Select("namespaces.*").
Table("namespaces").
Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id").
Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id").
@ -268,7 +264,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
return nil, 0, 0, err
}
numberOfTotalItems, err = x.
numberOfTotalItems, err = s.
Table("namespaces").
Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id").
Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id").
@ -294,7 +290,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all owners
userMap := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&userMap)
err = s.In("id", userIDs).Find(&userMap)
if err != nil {
return nil, 0, 0, err
}
@ -306,7 +302,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all lists
lists := []*List{}
listQuery := x.
listQuery := s.
In("namespace_id", namespaceids)
if !n.IsArchived {
@ -330,7 +326,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all lists individually shared with our user (not via a namespace)
individualLists := []*List{}
iListQuery := x.Select("l.*").
iListQuery := s.Select("l.*").
Table("list").
Alias("l").
Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id").
@ -360,7 +356,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
}
// More details for the lists
err = AddListDetails(lists)
err = addListDetails(s, lists)
if err != nil {
return nil, 0, 0, err
}
@ -386,7 +382,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Check if we have any favorites or favorited lists and remove the favorites namespace from the list if not
var favoriteCount int64
favoriteCount, err = x.
favoriteCount, err = s.
Join("INNER", "list", "tasks.list_id = list.id").
Join("INNER", "namespaces", "list.namespace_id = namespaces.id").
Where(builder.And(builder.Eq{"tasks.is_favorite": true}, builder.In("namespaces.id", namespaceids))).
@ -413,7 +409,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
/////////////////
// Saved Filters
savedFilters, err := getSavedFiltersForUser(a)
savedFilters, err := getSavedFiltersForUser(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -457,7 +453,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces [put]
func (n *Namespace) Create(a web.Auth) (err error) {
func (n *Namespace) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if we have at least a name
if n.Title == "" {
return ErrNamespaceNameCannotBeEmpty{NamespaceID: 0, UserID: a.GetID()}
@ -465,14 +461,14 @@ func (n *Namespace) Create(a web.Auth) (err error) {
n.ID = 0 // This would otherwise prevent the creation of new lists after one was created
// Check if the User exists
n.Owner, err = user.GetUserByID(a.GetID())
n.Owner, err = user.GetUserByID(s, a.GetID())
if err != nil {
return
}
n.OwnerID = n.Owner.ID
// Insert
if _, err = x.Insert(n); err != nil {
if _, err = s.Insert(n); err != nil {
return err
}
@ -482,12 +478,12 @@ func (n *Namespace) Create(a web.Auth) (err error) {
// CreateNewNamespaceForUser creates a new namespace for a user. To prevent import cycles, we can't do that
// directly in the user.Create function.
func CreateNewNamespaceForUser(user *user.User) (err error) {
func CreateNewNamespaceForUser(s *xorm.Session, user *user.User) (err error) {
newN := &Namespace{
Title: user.Username,
Description: user.Username + "'s namespace.",
}
return newN.Create(user)
return newN.Create(s, user)
}
// Delete deletes a namespace
@ -502,22 +498,22 @@ func CreateNewNamespaceForUser(user *user.User) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id} [delete]
func (n *Namespace) Delete() (err error) {
func (n *Namespace) Delete(s *xorm.Session) (err error) {
// Check if the namespace exists
_, err = GetNamespaceByID(n.ID)
_, err = GetNamespaceByID(s, n.ID)
if err != nil {
return
}
// Delete the namespace
_, err = x.ID(n.ID).Delete(&Namespace{})
_, err = s.ID(n.ID).Delete(&Namespace{})
if err != nil {
return
}
// Delete all lists with their tasks
lists, err := GetListsByNamespaceID(n.ID, &user.User{})
lists, err := GetListsByNamespaceID(s, n.ID, &user.User{})
if err != nil {
return
}
@ -530,13 +526,13 @@ func (n *Namespace) Delete() (err error) {
}
// Delete tasks
_, err = x.In("list_id", listIDs).Delete(&Task{})
_, err = s.In("list_id", listIDs).Delete(&Task{})
if err != nil {
return
}
// Delete the lists
_, err = x.In("id", listIDs).Delete(&List{})
_, err = s.In("id", listIDs).Delete(&List{})
if err != nil {
return
}
@ -560,14 +556,14 @@ func (n *Namespace) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespace/{id} [post]
func (n *Namespace) Update() (err error) {
func (n *Namespace) Update(s *xorm.Session) (err error) {
// Check if we have at least a name
if n.Title == "" {
return ErrNamespaceNameCannotBeEmpty{NamespaceID: n.ID}
}
// Check if the namespace exists
currentNamespace, err := GetNamespaceByID(n.ID)
currentNamespace, err := GetNamespaceByID(s, n.ID)
if err != nil {
return
}
@ -581,7 +577,7 @@ func (n *Namespace) Update() (err error) {
if n.Owner != nil {
n.OwnerID = n.Owner.ID
if currentNamespace.OwnerID != n.OwnerID {
n.Owner, err = user.GetUserByID(n.OwnerID)
n.Owner, err = user.GetUserByID(s, n.OwnerID)
if err != nil {
return
}
@ -599,7 +595,7 @@ func (n *Namespace) Update() (err error) {
}
// Do the actual update
_, err = x.
_, err = s.
ID(currentNamespace.ID).
Cols(colsToUpdate...).
Update(n)

View File

@ -19,37 +19,38 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/builder"
"xorm.io/xorm"
)
// CanWrite checks if a user has write access to a namespace
func (n *Namespace) CanWrite(a web.Auth) (bool, error) {
can, _, err := n.checkRight(a, RightWrite, RightAdmin)
func (n *Namespace) CanWrite(s *xorm.Session, a web.Auth) (bool, error) {
can, _, err := n.checkRight(s, a, RightWrite, RightAdmin)
return can, err
}
// IsAdmin returns true or false if the user is admin on that namespace or not
func (n *Namespace) IsAdmin(a web.Auth) (bool, error) {
is, _, err := n.checkRight(a, RightAdmin)
func (n *Namespace) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
is, _, err := n.checkRight(s, a, RightAdmin)
return is, err
}
// CanRead checks if a user has read access to that namespace
func (n *Namespace) CanRead(a web.Auth) (bool, int, error) {
return n.checkRight(a, RightRead, RightWrite, RightAdmin)
func (n *Namespace) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
return n.checkRight(s, a, RightRead, RightWrite, RightAdmin)
}
// CanUpdate checks if the user can update the namespace
func (n *Namespace) CanUpdate(a web.Auth) (bool, error) {
return n.IsAdmin(a)
func (n *Namespace) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return n.IsAdmin(s, a)
}
// CanDelete checks if the user can delete a namespace
func (n *Namespace) CanDelete(a web.Auth) (bool, error) {
return n.IsAdmin(a)
func (n *Namespace) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return n.IsAdmin(s, a)
}
// CanCreate checks if the user can create a new namespace
func (n *Namespace) CanCreate(a web.Auth) (bool, error) {
func (n *Namespace) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is {
return false, nil
}
@ -58,7 +59,7 @@ func (n *Namespace) CanCreate(a web.Auth) (bool, error) {
return true, nil
}
func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
func (n *Namespace) checkRight(s *xorm.Session, a web.Auth, rights ...Right) (bool, int, error) {
// If the auth is a link share, don't do anything
if _, is := a.(*LinkSharing); is {
@ -66,13 +67,12 @@ func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
}
// Get the namespace and check the right
nn := &Namespace{ID: n.ID}
err := nn.GetSimpleByID()
nn, err := getNamespaceSimpleByID(s, n.ID)
if err != nil {
return false, 0, err
}
if a.GetID() == n.OwnerID {
if a.GetID() == nn.OwnerID {
return true, int(RightAdmin), nil
}
@ -113,7 +113,8 @@ func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
var maxRights = 0
r := &allRights{}
exists, err := x.Select("*").
exists, err := s.
Select("*").
Table("namespaces").
// User stuff
Join("LEFT", "users_namespace", "users_namespace.namespace_id = namespaces.id").

View File

@ -20,6 +20,7 @@ import (
"time"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// TeamNamespace defines the relationship between a Team and a Namespace
@ -62,7 +63,7 @@ func (TeamNamespace) TableName() string {
// @Failure 403 {object} web.HTTPError "The team does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/teams [put]
func (tn *TeamNamespace) Create(a web.Auth) (err error) {
func (tn *TeamNamespace) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the rights are valid
if err = tn.Right.isValid(); err != nil {
@ -70,19 +71,20 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) {
}
// Check if the team exists
_, err = GetTeamByID(tn.TeamID)
_, err = GetTeamByID(s, tn.TeamID)
if err != nil {
return
}
// Check if the namespace exists
_, err = GetNamespaceByID(tn.NamespaceID)
_, err = GetNamespaceByID(s, tn.NamespaceID)
if err != nil {
return
}
// Check if the team already has access to the namespace
exists, err := x.Where("team_id = ?", tn.TeamID).
exists, err := s.
Where("team_id = ?", tn.TeamID).
And("namespace_id = ?", tn.NamespaceID).
Get(&TeamNamespace{})
if err != nil {
@ -93,7 +95,7 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) {
}
// Insert the new team
_, err = x.Insert(tn)
_, err = s.Insert(tn)
return
}
@ -110,16 +112,17 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "team or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/teams/{teamID} [delete]
func (tn *TeamNamespace) Delete() (err error) {
func (tn *TeamNamespace) Delete(s *xorm.Session) (err error) {
// Check if the team exists
_, err = GetTeamByID(tn.TeamID)
_, err = GetTeamByID(s, tn.TeamID)
if err != nil {
return
}
// Check if the team has access to the namespace
has, err := x.Where("team_id = ? AND namespace_id = ?", tn.TeamID, tn.NamespaceID).
has, err := s.
Where("team_id = ? AND namespace_id = ?", tn.TeamID, tn.NamespaceID).
Get(&TeamNamespace{})
if err != nil {
return
@ -129,7 +132,8 @@ func (tn *TeamNamespace) Delete() (err error) {
}
// Delete the relation
_, err = x.Where("team_id = ?", tn.TeamID).
_, err = s.
Where("team_id = ?", tn.TeamID).
And("namespace_id = ?", tn.NamespaceID).
Delete(TeamNamespace{})
@ -151,10 +155,10 @@ func (tn *TeamNamespace) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the namespace."
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/teams [get]
func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (tn *TeamNamespace) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user can read the namespace
n := Namespace{ID: tn.NamespaceID}
canRead, _, err := n.CanRead(a)
canRead, _, err := n.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -167,7 +171,8 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in
limit, start := getLimitFromPageIndex(page, perPage)
query := x.Table("teams").
query := s.
Table("teams").
Join("INNER", "team_namespaces", "team_id = teams.id").
Where("team_namespaces.namespace_id = ?", tn.NamespaceID).
Where("teams.name LIKE ?", "%"+search+"%")
@ -184,12 +189,13 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in
teams = append(teams, &t.Team)
}
err = addMoreInfoToTeams(teams)
err = addMoreInfoToTeams(s, teams)
if err != nil {
return
}
numberOfTotalItems, err = x.Table("teams").
numberOfTotalItems, err = s.
Table("teams").
Join("INNER", "team_namespaces", "team_id = teams.id").
Where("team_namespaces.namespace_id = ?", tn.NamespaceID).
Where("teams.name LIKE ?", "%"+search+"%").
@ -213,14 +219,14 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in
// @Failure 404 {object} web.HTTPError "Team or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/teams/{teamID} [post]
func (tn *TeamNamespace) Update() (err error) {
func (tn *TeamNamespace) Update(s *xorm.Session) (err error) {
// Check if the right is valid
if err := tn.Right.isValid(); err != nil {
return err
}
_, err = x.
_, err = s.
Where("namespace_id = ? AND team_id = ?", tn.NamespaceID, tn.TeamID).
Cols("right").
Update(tn)

View File

@ -18,22 +18,23 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if one can create a new team <-> namespace relation
func (tn *TeamNamespace) CanCreate(a web.Auth) (bool, error) {
func (tn *TeamNamespace) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: tn.NamespaceID}
return n.IsAdmin(a)
return n.IsAdmin(s, a)
}
// CanDelete checks if a user can remove a team from a namespace. Only namespace admins can do that.
func (tn *TeamNamespace) CanDelete(a web.Auth) (bool, error) {
func (tn *TeamNamespace) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: tn.NamespaceID}
return n.IsAdmin(a)
return n.IsAdmin(s, a)
}
// CanUpdate checks if a user can update a team from a Only namespace admins can do that.
func (tn *TeamNamespace) CanUpdate(a web.Auth) (bool, error) {
func (tn *TeamNamespace) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: tn.NamespaceID}
return n.IsAdmin(a)
return n.IsAdmin(s, a)
}

View File

@ -80,6 +80,7 @@ func TestTeamNamespace_CanDoSomething(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tn := &TeamNamespace{
ID: tt.fields.ID,
@ -91,15 +92,16 @@ func TestTeamNamespace_CanDoSomething(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
if got, _ := tn.CanCreate(tt.args.a); got != tt.want["CanCreate"] {
if got, _ := tn.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
t.Errorf("TeamNamespace.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
}
if got, _ := tn.CanDelete(tt.args.a); got != tt.want["CanDelete"] {
if got, _ := tn.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("TeamNamespace.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
}
if got, _ := tn.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] {
if got, _ := tn.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("TeamNamespace.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
}
_ = s.Close()
})
}
}

View File

@ -36,29 +36,35 @@ func TestTeamNamespace_ReadAll(t *testing.T) {
NamespaceID: 3,
}
db.LoadAndAssertFixtures(t)
teams, _, _, err := tn.ReadAll(u, "", 1, 50)
s := db.NewSession()
teams, _, _, err := tn.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice)
s := reflect.ValueOf(teams)
assert.Equal(t, s.Len(), 2)
ts := reflect.ValueOf(teams)
assert.Equal(t, ts.Len(), 2)
_ = s.Close()
})
t.Run("nonexistant namespace", func(t *testing.T) {
tn := TeamNamespace{
NamespaceID: 9999,
}
db.LoadAndAssertFixtures(t)
_, _, _, err := tn.ReadAll(u, "", 1, 50)
s := db.NewSession()
_, _, _, err := tn.ReadAll(s, u, "", 1, 50)
assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
})
t.Run("no right for namespace", func(t *testing.T) {
tn := TeamNamespace{
NamespaceID: 17,
}
db.LoadAndAssertFixtures(t)
_, _, _, err := tn.ReadAll(u, "", 1, 50)
s := db.NewSession()
_, _, _, err := tn.ReadAll(s, u, "", 1, 50)
assert.Error(t, err)
assert.True(t, IsErrNeedToHaveNamespaceReadAccess(err))
_ = s.Close()
})
}
@ -72,10 +78,15 @@ func TestTeamNamespace_Create(t *testing.T) {
Right: RightAdmin,
}
db.LoadAndAssertFixtures(t)
allowed, _ := tn.CanCreate(u)
s := db.NewSession()
allowed, _ := tn.CanCreate(s, u)
assert.True(t, allowed)
err := tn.Create(u)
err := tn.Create(s, u)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_namespaces", map[string]interface{}{
"team_id": 1,
"namespace_id": 1,
@ -89,9 +100,11 @@ func TestTeamNamespace_Create(t *testing.T) {
Right: RightRead,
}
db.LoadAndAssertFixtures(t)
err := tn.Create(u)
s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrTeamAlreadyHasAccess(err))
_ = s.Close()
})
t.Run("invalid team right", func(t *testing.T) {
tn := TeamNamespace{
@ -100,9 +113,11 @@ func TestTeamNamespace_Create(t *testing.T) {
Right: RightUnknown,
}
db.LoadAndAssertFixtures(t)
err := tn.Create(u)
s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrInvalidRight(err))
_ = s.Close()
})
t.Run("nonexistant team", func(t *testing.T) {
tn := TeamNamespace{
@ -110,9 +125,11 @@ func TestTeamNamespace_Create(t *testing.T) {
NamespaceID: 1,
}
db.LoadAndAssertFixtures(t)
err := tn.Create(u)
s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
})
t.Run("nonexistant namespace", func(t *testing.T) {
tn := TeamNamespace{
@ -120,9 +137,11 @@ func TestTeamNamespace_Create(t *testing.T) {
NamespaceID: 9999,
}
db.LoadAndAssertFixtures(t)
err := tn.Create(u)
s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
})
}
@ -135,10 +154,14 @@ func TestTeamNamespace_Delete(t *testing.T) {
NamespaceID: 9,
}
db.LoadAndAssertFixtures(t)
allowed, _ := tn.CanDelete(u)
s := db.NewSession()
allowed, _ := tn.CanDelete(s, u)
assert.True(t, allowed)
err := tn.Delete()
err := tn.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "team_namespaces", map[string]interface{}{
"team_id": 7,
"namespace_id": 9,
@ -150,9 +173,11 @@ func TestTeamNamespace_Delete(t *testing.T) {
NamespaceID: 3,
}
db.LoadAndAssertFixtures(t)
err := tn.Delete()
s := db.NewSession()
err := tn.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
})
t.Run("nonexistant namespace", func(t *testing.T) {
tn := TeamNamespace{
@ -160,9 +185,11 @@ func TestTeamNamespace_Delete(t *testing.T) {
NamespaceID: 9999,
}
db.LoadAndAssertFixtures(t)
err := tn.Delete()
s := db.NewSession()
err := tn.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotHaveAccessToNamespace(err))
_ = s.Close()
})
}
@ -221,6 +248,7 @@ func TestTeamNamespace_Update(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := &TeamNamespace{
ID: tt.fields.ID,
@ -232,13 +260,17 @@ func TestTeamNamespace_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := tl.Update()
err := tl.Update(s)
if (err != nil) != tt.wantErr {
t.Errorf("TeamNamespace.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("TeamNamespace.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "team_namespaces", map[string]interface{}{
"team_id": tt.fields.TeamID,

View File

@ -36,8 +36,12 @@ func TestNamespace_Create(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
err := dummynamespace.Create(user1)
s := db.NewSession()
err := dummynamespace.Create(s, user1)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "namespaces", map[string]interface{}{
"title": "Test",
"description": "Lorem Ipsum",
@ -45,18 +49,22 @@ func TestNamespace_Create(t *testing.T) {
})
t.Run("no title", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
n2 := Namespace{}
err := n2.Create(user1)
err := n2.Create(s, user1)
assert.Error(t, err)
assert.True(t, IsErrNamespaceNameCannotBeEmpty(err))
_ = s.Close()
})
t.Run("nonexistant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
nUser := &user.User{ID: 9482385}
dnsp2 := dummynamespace
err := dnsp2.Create(nUser)
err := dnsp2.Create(s, nUser)
assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
})
}
@ -64,28 +72,36 @@ func TestNamespace_ReadOne(t *testing.T) {
t.Run("normal", func(t *testing.T) {
n := &Namespace{ID: 1}
db.LoadAndAssertFixtures(t)
err := n.ReadOne()
s := db.NewSession()
err := n.ReadOne(s)
assert.NoError(t, err)
assert.Equal(t, n.Title, "testnamespace")
_ = s.Close()
})
t.Run("nonexistant", func(t *testing.T) {
n := &Namespace{ID: 99999}
db.LoadAndAssertFixtures(t)
err := n.ReadOne()
s := db.NewSession()
err := n.ReadOne(s)
assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
})
}
func TestNamespace_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{
ID: 1,
Title: "Lorem Ipsum",
}
err := n.Update()
err := n.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "namespaces", map[string]interface{}{
"id": 1,
"title": "Lorem Ipsum",
@ -93,56 +109,68 @@ func TestNamespace_Update(t *testing.T) {
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{
ID: 99999,
Title: "Lorem Ipsum",
}
err := n.Update()
err := n.Update(s)
assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
})
t.Run("nonexisting owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{
ID: 1,
Title: "Lorem Ipsum",
Owner: &user.User{ID: 99999},
}
err := n.Update()
err := n.Update(s)
assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
})
t.Run("no title", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{
ID: 1,
}
err := n.Update()
err := n.Update(s)
assert.Error(t, err)
assert.True(t, IsErrNamespaceNameCannotBeEmpty(err))
_ = s.Close()
})
}
func TestNamespace_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{
ID: 1,
}
err := n.Delete()
err := n.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "namespaces", map[string]interface{}{
"id": 1,
})
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{
ID: 9999,
}
err := n.Delete()
err := n.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
})
}
@ -152,9 +180,12 @@ func TestNamespace_ReadAll(t *testing.T) {
user11 := &user.User{ID: 11}
user12 := &user.User{ID: 12}
s := db.NewSession()
defer s.Close()
t.Run("normal", func(t *testing.T) {
n := &Namespace{}
nn, _, _, err := n.ReadAll(user1, "", 1, -1)
nn, _, _, err := n.ReadAll(s, user1, "", 1, -1)
assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces)
@ -174,7 +205,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{
NamespacesOnly: true,
}
nn, _, _, err := n.ReadAll(user1, "", 1, -1)
nn, _, _, err := n.ReadAll(s, user1, "", 1, -1)
assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces)
@ -188,7 +219,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{
NamespacesOnly: true,
}
nn, _, _, err := n.ReadAll(user7, "13,14", 1, -1)
nn, _, _, err := n.ReadAll(s, user7, "13,14", 1, -1)
assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces)
@ -200,7 +231,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{
NamespacesOnly: true,
}
nn, _, _, err := n.ReadAll(user1, "1,w", 1, -1)
nn, _, _, err := n.ReadAll(s, user1, "1,w", 1, -1)
assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces)
@ -211,7 +242,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{
IsArchived: true,
}
nn, _, _, err := n.ReadAll(user1, "", 1, -1)
nn, _, _, err := n.ReadAll(s, user1, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err)
assert.NotNil(t, namespaces)
@ -222,7 +253,7 @@ func TestNamespace_ReadAll(t *testing.T) {
})
t.Run("no favorites", func(t *testing.T) {
n := &Namespace{}
nn, _, _, err := n.ReadAll(user11, "", 1, -1)
nn, _, _, err := n.ReadAll(s, user11, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err)
// Assert the first namespace is not the favorites namespace
@ -230,7 +261,7 @@ func TestNamespace_ReadAll(t *testing.T) {
})
t.Run("no favorite tasks but namespace", func(t *testing.T) {
n := &Namespace{}
nn, _, _, err := n.ReadAll(user12, "", 1, -1)
nn, _, _, err := n.ReadAll(s, user12, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err)
// Assert the first namespace is the favorites namespace and contains lists
@ -239,7 +270,7 @@ func TestNamespace_ReadAll(t *testing.T) {
})
t.Run("no saved filters", func(t *testing.T) {
n := &Namespace{}
nn, _, _, err := n.ReadAll(user11, "", 1, -1)
nn, _, _, err := n.ReadAll(s, user11, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err)
// Assert the first namespace is not the favorites namespace

View File

@ -21,6 +21,7 @@ import (
user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// NamespaceUser represents a namespace <-> user relation
@ -64,7 +65,7 @@ func (NamespaceUser) TableName() string {
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/users [put]
func (nu *NamespaceUser) Create(a web.Auth) (err error) {
func (nu *NamespaceUser) Create(s *xorm.Session, a web.Auth) (err error) {
// Reset the id
nu.ID = 0
@ -74,13 +75,13 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
}
// Check if the namespace exists
l, err := GetNamespaceByID(nu.NamespaceID)
l, err := GetNamespaceByID(s, nu.NamespaceID)
if err != nil {
return
}
// Check if the user exists
user, err := user2.GetUserByUsername(nu.Username)
user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil {
return err
}
@ -92,7 +93,9 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
return ErrUserAlreadyHasNamespaceAccess{UserID: nu.UserID, NamespaceID: nu.NamespaceID}
}
exist, err := x.Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID).Get(&NamespaceUser{})
exist, err := s.
Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID).
Get(&NamespaceUser{})
if err != nil {
return
}
@ -101,7 +104,7 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
}
// Insert user <-> namespace relation
_, err = x.Insert(nu)
_, err = s.Insert(nu)
return
}
@ -119,17 +122,18 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "user or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/users/{userID} [delete]
func (nu *NamespaceUser) Delete() (err error) {
func (nu *NamespaceUser) Delete(s *xorm.Session) (err error) {
// Check if the user exists
user, err := user2.GetUserByUsername(nu.Username)
user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil {
return
}
nu.UserID = user.ID
// Check if the user has access to the namespace
has, err := x.Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID).
has, err := s.
Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID).
Get(&NamespaceUser{})
if err != nil {
return
@ -138,7 +142,8 @@ func (nu *NamespaceUser) Delete() (err error) {
return ErrUserDoesNotHaveAccessToNamespace{NamespaceID: nu.NamespaceID, UserID: nu.UserID}
}
_, err = x.Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID).
_, err = s.
Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID).
Delete(&NamespaceUser{})
return
}
@ -158,10 +163,10 @@ func (nu *NamespaceUser) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the namespace."
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/users [get]
func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (nu *NamespaceUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has access to the namespace
l := Namespace{ID: nu.NamespaceID}
canRead, _, err := l.CanRead(a)
canRead, _, err := l.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -174,7 +179,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in
limit, start := getLimitFromPageIndex(page, perPage)
query := x.
query := s.
Join("INNER", "users_namespace", "user_id = users.id").
Where("users_namespace.namespace_id = ?", nu.NamespaceID).
Where("users.username LIKE ?", "%"+search+"%")
@ -191,7 +196,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in
u.Email = ""
}
numberOfTotalItems, err = x.
numberOfTotalItems, err = s.
Join("INNER", "users_namespace", "user_id = users.id").
Where("users_namespace.namespace_id = ?", nu.NamespaceID).
Where("users.username LIKE ?", "%"+search+"%").
@ -215,7 +220,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in
// @Failure 404 {object} web.HTTPError "User or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/users/{userID} [post]
func (nu *NamespaceUser) Update() (err error) {
func (nu *NamespaceUser) Update(s *xorm.Session) (err error) {
// Check if the right is valid
if err := nu.Right.isValid(); err != nil {
@ -223,13 +228,13 @@ func (nu *NamespaceUser) Update() (err error) {
}
// Check if the user exists
user, err := user2.GetUserByUsername(nu.Username)
user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil {
return err
}
nu.UserID = user.ID
_, err = x.
_, err = s.
Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID).
Cols("right").
Update(nu)

View File

@ -18,24 +18,25 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if the user can create a new user <-> namespace relation
func (nu *NamespaceUser) CanCreate(a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(a)
func (nu *NamespaceUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(s, a)
}
// CanDelete checks if the user can delete a user <-> namespace relation
func (nu *NamespaceUser) CanDelete(a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(a)
func (nu *NamespaceUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(s, a)
}
// CanUpdate checks if the user can update a user <-> namespace relation
func (nu *NamespaceUser) CanUpdate(a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(a)
func (nu *NamespaceUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(s, a)
}
func (nu *NamespaceUser) canDoNamespaceUser(a web.Auth) (bool, error) {
func (nu *NamespaceUser) canDoNamespaceUser(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: nu.NamespaceID}
return n.IsAdmin(a)
return n.IsAdmin(s, a)
}

View File

@ -80,6 +80,8 @@ func TestNamespaceUser_CanDoSomething(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
nu := &NamespaceUser{
ID: tt.fields.ID,
@ -91,13 +93,13 @@ func TestNamespaceUser_CanDoSomething(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
if got, _ := nu.CanCreate(tt.args.a); got != tt.want["CanCreate"] {
if got, _ := nu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
t.Errorf("NamespaceUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
}
if got, _ := nu.CanDelete(tt.args.a); got != tt.want["CanDelete"] {
if got, _ := nu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("NamespaceUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
}
if got, _ := nu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] {
if got, _ := nu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("NamespaceUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
}
})

View File

@ -25,6 +25,7 @@ import (
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"github.com/stretchr/testify/assert"
"gopkg.in/d4l3k/messagediff.v1"
)
@ -108,6 +109,7 @@ func TestNamespaceUser_Create(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
un := &NamespaceUser{
ID: tt.fields.ID,
@ -119,13 +121,16 @@ func TestNamespaceUser_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := un.Create(tt.args.a)
err := un.Create(s, tt.args.a)
if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.Create() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("NamespaceUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "users_namespace", map[string]interface{}{
"user_id": tt.fields.UserID,
@ -211,6 +216,8 @@ func TestNamespaceUser_ReadAll(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
un := &NamespaceUser{
ID: tt.fields.ID,
@ -222,7 +229,7 @@ func TestNamespaceUser_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
got, _, _, err := un.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50)
got, _, _, err := un.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
return
@ -296,6 +303,7 @@ func TestNamespaceUser_Update(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
nu := &NamespaceUser{
ID: tt.fields.ID,
@ -307,13 +315,16 @@ func TestNamespaceUser_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := nu.Update()
err := nu.Update(s)
if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("NamespaceUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertExists(t, "users_namespace", map[string]interface{}{
"user_id": tt.fields.UserID,
@ -373,6 +384,7 @@ func TestNamespaceUser_Delete(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
nu := &NamespaceUser{
ID: tt.fields.ID,
@ -384,13 +396,16 @@ func TestNamespaceUser_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
err := nu.Delete()
err := nu.Delete(s)
if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.Delete() error = %v, wantErr %v", err, tt.wantErr)
}
if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("NamespaceUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
}
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr {
db.AssertMissing(t, "users_namespace", map[string]interface{}{
"user_id": tt.fields.UserID,

View File

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// SavedFilter represents a saved bunch of filters
@ -48,14 +49,14 @@ type SavedFilter struct {
}
// TableName returns a better table name for saved filters
func (s *SavedFilter) TableName() string {
func (sf *SavedFilter) TableName() string {
return "saved_filters"
}
func (s *SavedFilter) getTaskCollection() *TaskCollection {
func (sf *SavedFilter) getTaskCollection() *TaskCollection {
// We're resetting the listID to return tasks from all lists
s.Filters.ListID = 0
return s.Filters
sf.Filters.ListID = 0
return sf.Filters
}
// Returns the saved filter ID from a list ID. Will not check if the filter actually exists.
@ -79,13 +80,13 @@ func getListIDFromSavedFilterID(filterID int64) (listID int64) {
return
}
func getSavedFiltersForUser(auth web.Auth) (filters []*SavedFilter, err error) {
func getSavedFiltersForUser(s *xorm.Session, auth web.Auth) (filters []*SavedFilter, err error) {
// Link shares can't view or modify saved filters, therefore we can error out right away
if _, is := auth.(*LinkSharing); is {
return nil, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID()}
}
err = x.Where("owner_id = ?", auth.GetID()).Find(&filters)
err = s.Where("owner_id = ?", auth.GetID()).Find(&filters)
return
}
@ -100,17 +101,17 @@ func getSavedFiltersForUser(auth web.Auth) (filters []*SavedFilter, err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter."
// @Failure 500 {object} models.Message "Internal error"
// @Router /filters [put]
func (s *SavedFilter) Create(auth web.Auth) error {
s.OwnerID = auth.GetID()
_, err := x.Insert(s)
func (sf *SavedFilter) Create(s *xorm.Session, auth web.Auth) error {
sf.OwnerID = auth.GetID()
_, err := s.Insert(sf)
return err
}
func getSavedFilterSimpleByID(id int64) (s *SavedFilter, err error) {
s = &SavedFilter{}
exists, err := x.
func getSavedFilterSimpleByID(s *xorm.Session, id int64) (sf *SavedFilter, err error) {
sf = &SavedFilter{}
exists, err := s.
Where("id = ?", id).
Get(s)
Get(sf)
if err != nil {
return nil, err
}
@ -132,10 +133,10 @@ func getSavedFilterSimpleByID(id int64) (s *SavedFilter, err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter."
// @Failure 500 {object} models.Message "Internal error"
// @Router /filters/{id} [get]
func (s *SavedFilter) ReadOne() error {
func (sf *SavedFilter) ReadOne(s *xorm.Session) error {
// s already contains almost the full saved filter from the rights check, we only need to add the user
u, err := user.GetUserByID(s.OwnerID)
s.Owner = u
u, err := user.GetUserByID(s, sf.OwnerID)
sf.Owner = u
return err
}
@ -152,15 +153,15 @@ func (s *SavedFilter) ReadOne() error {
// @Failure 404 {object} web.HTTPError "The saved filter does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /filters/{id} [post]
func (s *SavedFilter) Update() error {
_, err := x.
Where("id = ?", s.ID).
func (sf *SavedFilter) Update(s *xorm.Session) error {
_, err := s.
Where("id = ?", sf.ID).
Cols(
"title",
"description",
"filters",
).
Update(s)
Update(sf)
return err
}
@ -177,7 +178,9 @@ func (s *SavedFilter) Update() error {
// @Failure 404 {object} web.HTTPError "The saved filter does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /filters/{id} [delete]
func (s *SavedFilter) Delete() error {
_, err := x.Where("id = ?", s.ID).Delete(s)
func (sf *SavedFilter) Delete(s *xorm.Session) error {
_, err := s.
Where("id = ?", sf.ID).
Delete(sf)
return err
}

View File

@ -16,28 +16,31 @@
package models
import "code.vikunja.io/web"
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead checks if a user has the right to read a saved filter
func (s *SavedFilter) CanRead(auth web.Auth) (bool, int, error) {
can, err := s.canDoFilter(auth)
func (sf *SavedFilter) CanRead(s *xorm.Session, auth web.Auth) (bool, int, error) {
can, err := sf.canDoFilter(s, auth)
return can, int(RightAdmin), err
}
// CanDelete checks if a user has the right to delete a saved filter
func (s *SavedFilter) CanDelete(auth web.Auth) (bool, error) {
return s.canDoFilter(auth)
func (sf *SavedFilter) CanDelete(s *xorm.Session, auth web.Auth) (bool, error) {
return sf.canDoFilter(s, auth)
}
// CanUpdate checks if a user has the right to update a saved filter
func (s *SavedFilter) CanUpdate(auth web.Auth) (bool, error) {
func (sf *SavedFilter) CanUpdate(s *xorm.Session, auth web.Auth) (bool, error) {
// A normal check would replace the passed struct which in our case would override the values we want to update.
sf := &SavedFilter{ID: s.ID}
return sf.canDoFilter(auth)
sff := &SavedFilter{ID: sf.ID}
return sff.canDoFilter(s, auth)
}
// CanCreate checks if a user has the right to update a saved filter
func (s *SavedFilter) CanCreate(auth web.Auth) (bool, error) {
func (sf *SavedFilter) CanCreate(s *xorm.Session, auth web.Auth) (bool, error) {
if _, is := auth.(*LinkSharing); is {
return false, nil
}
@ -46,23 +49,23 @@ func (s *SavedFilter) CanCreate(auth web.Auth) (bool, error) {
}
// Helper function to check saved filter rights sind they all have the same logic
func (s *SavedFilter) canDoFilter(auth web.Auth) (can bool, err error) {
func (sf *SavedFilter) canDoFilter(s *xorm.Session, auth web.Auth) (can bool, err error) {
// Link shares can't view or modify saved filters, therefore we can error out right away
if _, is := auth.(*LinkSharing); is {
return false, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID(), SavedFilterID: s.ID}
return false, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID(), SavedFilterID: sf.ID}
}
sf, err := getSavedFilterSimpleByID(s.ID)
sff, err := getSavedFilterSimpleByID(s, sf.ID)
if err != nil {
return false, err
}
// Only owners are allowed to do something with a saved filter
if sf.OwnerID != auth.GetID() {
if sff.OwnerID != auth.GetID() {
return false, nil
}
*s = *sf
*sf = *sff
return true, nil
}

View File

@ -45,6 +45,9 @@ func TestSavedFilter_getFilterIDFromListID(t *testing.T) {
func TestSavedFilter_Create(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
Title: "test",
Description: "Lorem Ipsum dolor sit amet",
@ -52,9 +55,11 @@ func TestSavedFilter_Create(t *testing.T) {
}
u := &user.User{ID: 1}
err := sf.Create(u)
err := sf.Create(s, u)
assert.NoError(t, err)
assert.Equal(t, u.ID, sf.OwnerID)
err = s.Commit()
assert.NoError(t, err)
vals := map[string]interface{}{
"title": "'test'",
"description": "'Lorem Ipsum dolor sit amet'",
@ -62,7 +67,7 @@ func TestSavedFilter_Create(t *testing.T) {
"owner_id": 1,
}
// Postgres can't compare json values directly, see https://dba.stackexchange.com/a/106290/210721
if x.Dialect().URI().DBType == schemas.POSTGRES {
if db.Type() == schemas.POSTGRES {
vals["filters::jsonb"] = vals["filters"].(string) + "::jsonb"
delete(vals, "filters")
}
@ -72,26 +77,34 @@ func TestSavedFilter_Create(t *testing.T) {
func TestSavedFilter_ReadOne(t *testing.T) {
user1 := &user.User{ID: 1}
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
}
// canRead pre-populates the struct
_, _, err := sf.CanRead(user1)
_, _, err := sf.CanRead(s, user1)
assert.NoError(t, err)
err = sf.ReadOne()
err = sf.ReadOne(s)
assert.NoError(t, err)
assert.NotNil(t, sf.Owner)
}
func TestSavedFilter_Update(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "NewTitle",
Description: "", // Explicitly reset the description
Filters: &TaskCollection{},
}
err := sf.Update()
err := sf.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "saved_filters", map[string]interface{}{
"id": 1,
@ -102,10 +115,15 @@ func TestSavedFilter_Update(t *testing.T) {
func TestSavedFilter_Delete(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
}
err := sf.Delete()
err := sf.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "saved_filters", map[string]interface{}{
"id": 1,
@ -120,50 +138,65 @@ func TestSavedFilter_Rights(t *testing.T) {
t.Run("create", func(t *testing.T) {
// Should always be true
db.LoadAndAssertFixtures(t)
can, err := (&SavedFilter{}).CanCreate(user1)
s := db.NewSession()
defer s.Close()
can, err := (&SavedFilter{}).CanCreate(s, user1)
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("read", func(t *testing.T) {
t.Run("owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "Lorem",
}
can, max, err := sf.CanRead(user1)
can, max, err := sf.CanRead(s, user1)
assert.NoError(t, err)
assert.Equal(t, int(RightAdmin), max)
assert.True(t, can)
})
t.Run("not owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "Lorem",
}
can, _, err := sf.CanRead(user2)
can, _, err := sf.CanRead(s, user2)
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 9999,
Title: "Lorem",
}
can, _, err := sf.CanRead(user1)
can, _, err := sf.CanRead(s, user1)
assert.Error(t, err)
assert.True(t, IsErrSavedFilterDoesNotExist(err))
assert.False(t, can)
})
t.Run("link share", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "Lorem",
}
can, _, err := sf.CanRead(ls)
can, _, err := sf.CanRead(s, ls)
assert.Error(t, err)
assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err))
assert.False(t, can)
@ -172,42 +205,54 @@ func TestSavedFilter_Rights(t *testing.T) {
t.Run("update", func(t *testing.T) {
t.Run("owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "Lorem",
}
can, err := sf.CanUpdate(user1)
can, err := sf.CanUpdate(s, user1)
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("not owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "Lorem",
}
can, err := sf.CanUpdate(user2)
can, err := sf.CanUpdate(s, user2)
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 9999,
Title: "Lorem",
}
can, err := sf.CanUpdate(user1)
can, err := sf.CanUpdate(s, user1)
assert.Error(t, err)
assert.True(t, IsErrSavedFilterDoesNotExist(err))
assert.False(t, can)
})
t.Run("link share", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "Lorem",
}
can, err := sf.CanUpdate(ls)
can, err := sf.CanUpdate(s, ls)
assert.Error(t, err)
assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err))
assert.False(t, can)
@ -216,40 +261,52 @@ func TestSavedFilter_Rights(t *testing.T) {
t.Run("delete", func(t *testing.T) {
t.Run("owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
}
can, err := sf.CanDelete(user1)
can, err := sf.CanDelete(s, user1)
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("not owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
}
can, err := sf.CanDelete(user2)
can, err := sf.CanDelete(s, user2)
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 9999,
Title: "Lorem",
}
can, err := sf.CanDelete(user1)
can, err := sf.CanDelete(s, user1)
assert.Error(t, err)
assert.True(t, IsErrSavedFilterDoesNotExist(err))
assert.False(t, can)
})
t.Run("link share", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{
ID: 1,
Title: "Lorem",
}
can, err := sf.CanDelete(ls)
can, err := sf.CanDelete(s, ls)
assert.Error(t, err)
assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err))
assert.False(t, can)

View File

@ -46,9 +46,9 @@ type TaskAssigneeWithUser struct {
user.User `xorm:"extends"`
}
func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*TaskAssigneeWithUser, err error) {
func getRawTaskAssigneesForTasks(s *xorm.Session, taskIDs []int64) (taskAssignees []*TaskAssigneeWithUser, err error) {
taskAssignees = []*TaskAssigneeWithUser{}
err = x.Table("task_assignees").
err = s.Table("task_assignees").
Select("task_id, users.*").
In("task_id", taskIDs).
Join("INNER", "users", "task_assignees.user_id = users.id").
@ -60,7 +60,7 @@ func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*TaskAssignee
func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err error) {
// Load the current assignees
currentAssignees, err := getRawTaskAssigneesForTasks([]int64{t.ID})
currentAssignees, err := getRawTaskAssigneesForTasks(s, []int64{t.ID})
if err != nil {
return err
}
@ -118,8 +118,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err
}
// Get the list to perform later checks
list := List{ID: t.ListID}
err = list.GetSimpleByID()
list, err := GetListSimpleByID(s, t.ListID)
if err != nil {
return
}
@ -133,7 +132,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err
}
// Add the new assignee
err = t.addNewAssigneeByID(u.ID, &list)
err = t.addNewAssigneeByID(s, u.ID, list)
if err != nil {
return err
}
@ -141,7 +140,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err
t.setTaskAssignees(assignees)
err = updateListLastUpdated(&List{ID: t.ListID})
err = updateListLastUpdated(s, &List{ID: t.ListID})
return
}
@ -167,13 +166,13 @@ func (t *Task) setTaskAssignees(assignees []*user.User) {
// @Failure 403 {object} web.HTTPError "Not allowed to delete the assignee."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees/{userID} [delete]
func (la *TaskAssginee) Delete() (err error) {
_, err = x.Delete(&TaskAssginee{TaskID: la.TaskID, UserID: la.UserID})
func (la *TaskAssginee) Delete(s *xorm.Session) (err error) {
_, err = s.Delete(&TaskAssginee{TaskID: la.TaskID, UserID: la.UserID})
if err != nil {
return err
}
err = updateListByTaskID(la.TaskID)
err = updateListByTaskID(s, la.TaskID)
return
}
@ -190,25 +189,25 @@ func (la *TaskAssginee) Delete() (err error) {
// @Failure 400 {object} web.HTTPError "Invalid assignee object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees [put]
func (la *TaskAssginee) Create(a web.Auth) (err error) {
func (la *TaskAssginee) Create(s *xorm.Session, a web.Auth) (err error) {
// Get the list to perform later checks
list, err := GetListSimplByTaskID(la.TaskID)
list, err := GetListSimplByTaskID(s, la.TaskID)
if err != nil {
return
}
task := &Task{ID: la.TaskID}
return task.addNewAssigneeByID(la.UserID, list)
return task.addNewAssigneeByID(s, la.UserID, list)
}
func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) {
func (t *Task) addNewAssigneeByID(s *xorm.Session, newAssigneeID int64, list *List) (err error) {
// Check if the user exists and has access to the list
newAssignee, err := user.GetUserByID(newAssigneeID)
newAssignee, err := user.GetUserByID(s, newAssigneeID)
if err != nil {
return err
}
canRead, _, err := list.CanRead(newAssignee)
canRead, _, err := list.CanRead(s, newAssignee)
if err != nil {
return err
}
@ -216,7 +215,7 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) {
return ErrUserDoesNotHaveAccessToList{list.ID, newAssigneeID}
}
_, err = x.Insert(TaskAssginee{
_, err = s.Insert(TaskAssginee{
TaskID: t.ID,
UserID: newAssigneeID,
})
@ -224,7 +223,7 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) {
return err
}
err = updateListLastUpdated(&List{ID: t.ListID})
err = updateListLastUpdated(s, &List{ID: t.ListID})
return
}
@ -242,13 +241,13 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) {
// @Success 200 {array} user.User "The assignees"
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees [get]
func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
task, err := GetListSimplByTaskID(la.TaskID)
func (la *TaskAssginee) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
task, err := GetListSimplByTaskID(s, la.TaskID)
if err != nil {
return nil, 0, 0, err
}
can, _, err := task.CanRead(a)
can, _, err := task.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -258,7 +257,7 @@ func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int
limit, start := getLimitFromPageIndex(page, perPage)
var taskAssignees []*user.User
query := x.Table("task_assignees").
query := s.Table("task_assignees").
Select("users.*").
Join("INNER", "users", "task_assignees.user_id = users.id").
Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%")
@ -270,7 +269,7 @@ func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int
return nil, 0, 0, err
}
numberOfTotalItems, err = x.Table("task_assignees").
numberOfTotalItems, err = s.Table("task_assignees").
Select("users.*").
Join("INNER", "users", "task_assignees.user_id = users.id").
Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%").
@ -301,14 +300,12 @@ type BulkAssignees struct {
// @Failure 400 {object} web.HTTPError "Invalid assignee object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees/bulk [post]
func (ba *BulkAssignees) Create(a web.Auth) (err error) {
s := x.NewSession()
task, err := GetTaskByIDSimple(ba.TaskID)
func (ba *BulkAssignees) Create(s *xorm.Session, a web.Auth) (err error) {
task, err := GetTaskByIDSimple(s, ba.TaskID)
if err != nil {
return
}
assignees, err := getRawTaskAssigneesForTasks([]int64{task.ID})
assignees, err := getRawTaskAssigneesForTasks(s, []int64{task.ID})
if err != nil {
return err
}
@ -317,10 +314,5 @@ func (ba *BulkAssignees) Create(a web.Auth) (err error) {
}
err = task.updateTaskAssignees(s, ba.Assignees)
if err != nil {
_ = s.Rollback()
return err
}
return s.Commit()
return
}

View File

@ -18,28 +18,29 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if a user can add a new assignee
func (la *TaskAssginee) CanCreate(a web.Auth) (bool, error) {
return canDoTaskAssingee(la.TaskID, a)
func (la *TaskAssginee) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return canDoTaskAssingee(s, la.TaskID, a)
}
// CanCreate checks if a user can add a new assignee
func (ba *BulkAssignees) CanCreate(a web.Auth) (bool, error) {
return canDoTaskAssingee(ba.TaskID, a)
func (ba *BulkAssignees) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return canDoTaskAssingee(s, ba.TaskID, a)
}
// CanDelete checks if a user can delete an assignee
func (la *TaskAssginee) CanDelete(a web.Auth) (bool, error) {
return canDoTaskAssingee(la.TaskID, a)
func (la *TaskAssginee) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return canDoTaskAssingee(s, la.TaskID, a)
}
func canDoTaskAssingee(taskID int64, a web.Auth) (bool, error) {
func canDoTaskAssingee(s *xorm.Session, taskID int64, a web.Auth) (bool, error) {
// Check if the current user can edit the list
list, err := GetListSimplByTaskID(taskID)
list, err := GetListSimplByTaskID(s, taskID)
if err != nil {
return false, err
}
return list.CanUpdate(a)
return list.CanUpdate(s, a)
}

View File

@ -23,6 +23,7 @@ import (
"code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// TaskAttachment is the definition of a task attachment
@ -49,7 +50,7 @@ func (TaskAttachment) TableName() string {
// NewAttachment creates a new task attachment
// Note: I'm not sure if only accepting an io.ReadCloser and not an afero.File or os.File instead is a good way of doing things.
func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsize uint64, a web.Auth) error {
func (ta *TaskAttachment) NewAttachment(s *xorm.Session, f io.ReadCloser, realname string, realsize uint64, a web.Auth) error {
// Store the file
file, err := files.Create(f, realname, realsize, a)
@ -64,7 +65,7 @@ func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsi
// Add an entry to the db
ta.FileID = file.ID
ta.CreatedByID = a.GetID()
_, err = x.Insert(ta)
_, err = s.Insert(ta)
if err != nil {
// remove the uploaded file if adding it to the db fails
if err2 := file.Delete(); err2 != nil {
@ -77,8 +78,8 @@ func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsi
}
// ReadOne returns a task attachment
func (ta *TaskAttachment) ReadOne() (err error) {
exists, err := x.Where("id = ?", ta.ID).Get(ta)
func (ta *TaskAttachment) ReadOne(s *xorm.Session) (err error) {
exists, err := s.Where("id = ?", ta.ID).Get(ta)
if err != nil {
return
}
@ -110,12 +111,12 @@ func (ta *TaskAttachment) ReadOne() (err error) {
// @Failure 404 {object} models.Message "The task does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id}/attachments [get]
func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (ta *TaskAttachment) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
attachments := []*TaskAttachment{}
limit, start := getLimitFromPageIndex(page, perPage)
query := x.
query := s.
Where("task_id = ?", ta.TaskID)
if limit > 0 {
query = query.Limit(limit, start)
@ -133,13 +134,13 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i
}
fs := make(map[int64]*files.File)
err = x.In("id", fileIDs).Find(&fs)
err = s.In("id", fileIDs).Find(&fs)
if err != nil {
return nil, 0, 0, err
}
us := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&us)
err = s.In("id", userIDs).Find(&us)
if err != nil {
return nil, 0, 0, err
}
@ -153,7 +154,7 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i
r.CreatedBy = us[r.CreatedByID]
}
numberOfTotalItems, err = x.
numberOfTotalItems, err = s.
Where("task_id = ?", ta.TaskID).
Count(&TaskAttachment{})
return attachments, len(attachments), numberOfTotalItems, err
@ -173,15 +174,17 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i
// @Failure 404 {object} models.Message "The task does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id}/attachments/{attachmentID} [delete]
func (ta *TaskAttachment) Delete() error {
func (ta *TaskAttachment) Delete(s *xorm.Session) error {
// Load the attachment
err := ta.ReadOne()
err := ta.ReadOne(s)
if err != nil && !files.IsErrFileDoesNotExist(err) {
return err
}
// Delete it
_, err = x.Where("task_id = ? AND id = ?", ta.TaskID, ta.ID).Delete(ta)
_, err = s.
Where("task_id = ? AND id = ?", ta.TaskID, ta.ID).
Delete(ta)
if err != nil {
return err
}
@ -195,9 +198,9 @@ func (ta *TaskAttachment) Delete() error {
return err
}
func getTaskAttachmentsByTaskIDs(taskIDs []int64) (attachments []*TaskAttachment, err error) {
func getTaskAttachmentsByTaskIDs(s *xorm.Session, taskIDs []int64) (attachments []*TaskAttachment, err error) {
attachments = []*TaskAttachment{}
err = x.
err = s.
In("task_id", taskIDs).
Find(&attachments)
if err != nil {
@ -213,13 +216,13 @@ func getTaskAttachmentsByTaskIDs(taskIDs []int64) (attachments []*TaskAttachment
// Get all files
fs := make(map[int64]*files.File)
err = x.In("id", fileIDs).Find(&fs)
err = s.In("id", fileIDs).Find(&fs)
if err != nil {
return
}
users := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&users)
err = s.In("id", userIDs).Find(&users)
if err != nil {
return
}

View File

@ -16,25 +16,28 @@
package models
import "code.vikunja.io/web"
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead checks if the user can see an attachment
func (ta *TaskAttachment) CanRead(a web.Auth) (bool, int, error) {
func (ta *TaskAttachment) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
t := &Task{ID: ta.TaskID}
return t.CanRead(a)
return t.CanRead(s, a)
}
// CanDelete checks if the user can delete an attachment
func (ta *TaskAttachment) CanDelete(a web.Auth) (bool, error) {
func (ta *TaskAttachment) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
t := &Task{ID: ta.TaskID}
return t.CanWrite(a)
return t.CanWrite(s, a)
}
// CanCreate checks if the user can create an attachment
func (ta *TaskAttachment) CanCreate(a web.Auth) (bool, error) {
t, err := GetTaskByIDSimple(ta.TaskID)
func (ta *TaskAttachment) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
t, err := GetTaskByIDSimple(s, ta.TaskID)
if err != nil {
return false, err
}
return t.CanCreate(a)
return t.CanCreate(s, a)
}

View File

@ -33,11 +33,14 @@ import (
func TestTaskAttachment_ReadOne(t *testing.T) {
t.Run("Normal File", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t)
ta := &TaskAttachment{
ID: 1,
}
err := ta.ReadOne()
err := ta.ReadOne(s)
assert.NoError(t, err)
assert.NotNil(t, ta.File)
assert.True(t, ta.File.ID == ta.FileID && ta.FileID != 0)
@ -54,21 +57,27 @@ func TestTaskAttachment_ReadOne(t *testing.T) {
})
t.Run("Nonexisting Attachment", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t)
ta := &TaskAttachment{
ID: 9999,
}
err := ta.ReadOne()
err := ta.ReadOne(s)
assert.Error(t, err)
assert.True(t, IsErrTaskAttachmentDoesNotExist(err))
})
t.Run("Existing Attachment, Nonexisting File", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t)
ta := &TaskAttachment{
ID: 2,
}
err := ta.ReadOne()
err := ta.ReadOne(s)
assert.Error(t, err)
assert.EqualError(t, err, "file 9999 does not exist")
})
@ -94,6 +103,9 @@ func (t *testfile) Close() error {
func TestTaskAttachment_NewAttachment(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t)
// Assert the file is being stored correctly
ta := TaskAttachment{
@ -104,7 +116,7 @@ func TestTaskAttachment_NewAttachment(t *testing.T) {
}
testuser := &user.User{ID: 1}
err := ta.NewAttachment(tf, "testfile", 100, testuser)
err := ta.NewAttachment(s, tf, "testfile", 100, testuser)
assert.NoError(t, err)
assert.NotEqual(t, 0, ta.FileID)
_, err = files.FileStat("files/" + strconv.FormatInt(ta.FileID, 10))
@ -125,9 +137,12 @@ func TestTaskAttachment_NewAttachment(t *testing.T) {
func TestTaskAttachment_ReadAll(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t)
ta := &TaskAttachment{TaskID: 1}
as, _, _, err := ta.ReadAll(&user.User{ID: 1}, "", 0, 50)
as, _, _, err := ta.ReadAll(s, &user.User{ID: 1}, "", 0, 50)
attachments, _ := as.([]*TaskAttachment)
assert.NoError(t, err)
assert.Len(t, attachments, 2)
@ -136,10 +151,13 @@ func TestTaskAttachment_ReadAll(t *testing.T) {
func TestTaskAttachment_Delete(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t)
t.Run("Normal", func(t *testing.T) {
ta := &TaskAttachment{ID: 1}
err := ta.Delete()
err := ta.Delete(s)
assert.NoError(t, err)
// Check if the file itself was deleted
_, err = files.FileStat("/1") // The new file has the id 2 since it's the second attachment
@ -148,14 +166,14 @@ func TestTaskAttachment_Delete(t *testing.T) {
t.Run("Nonexisting", func(t *testing.T) {
files.InitTestFileFixtures(t)
ta := &TaskAttachment{ID: 9999}
err := ta.Delete()
err := ta.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrTaskAttachmentDoesNotExist(err))
})
t.Run("Existing attachment, nonexisting file", func(t *testing.T) {
files.InitTestFileFixtures(t)
ta := &TaskAttachment{ID: 2}
err := ta.Delete()
err := ta.Delete(s)
assert.NoError(t, err)
})
}
@ -165,15 +183,21 @@ func TestTaskAttachment_Rights(t *testing.T) {
t.Run("Can Read", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 1}
can, _, err := ta.CanRead(u)
can, _, err := ta.CanRead(s, u)
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("Forbidden", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 14}
can, _, err := ta.CanRead(u)
can, _, err := ta.CanRead(s, u)
assert.NoError(t, err)
assert.False(t, can)
})
@ -181,22 +205,31 @@ func TestTaskAttachment_Rights(t *testing.T) {
t.Run("Can Delete", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 1}
can, err := ta.CanDelete(u)
can, err := ta.CanDelete(s, u)
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("Forbidden, no access", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 14}
can, err := ta.CanDelete(u)
can, err := ta.CanDelete(s, u)
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("Forbidden, shared read only", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 15}
can, err := ta.CanDelete(u)
can, err := ta.CanDelete(s, u)
assert.NoError(t, err)
assert.False(t, can)
})
@ -204,22 +237,31 @@ func TestTaskAttachment_Rights(t *testing.T) {
t.Run("Can Create", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 1}
can, err := ta.CanCreate(u)
can, err := ta.CanCreate(s, u)
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("Forbidden, no access", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 14}
can, err := ta.CanCreate(u)
can, err := ta.CanCreate(s, u)
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("Forbidden, shared read only", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 15}
can, err := ta.CanCreate(u)
can, err := ta.CanCreate(s, u)
assert.NoError(t, err)
assert.False(t, can)
})

View File

@ -20,6 +20,7 @@ package models
import (
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// TaskCollection is a struct used to hold filter details and not clutter the Task struct with information not related to actual tasks.
@ -100,17 +101,17 @@ func validateTaskField(fieldName string) error {
// @Success 200 {array} models.Task "The tasks"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/tasks [get]
func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
func (tf *TaskCollection) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
// If the list id is < -1 this means we're dealing with a saved filter - in that case we get and populate the filter
// -1 is the favorites list which works as intended
if tf.ListID < -1 {
s, err := getSavedFilterSimpleByID(getSavedFilterIDFromListID(tf.ListID))
sf, err := getSavedFilterSimpleByID(s, getSavedFilterIDFromListID(tf.ListID))
if err != nil {
return nil, 0, 0, err
}
return s.getTaskCollection().ReadAll(a, search, page, perPage)
return sf.getTaskCollection().ReadAll(s, a, search, page, perPage)
}
if len(tf.SortByArr) > 0 {
@ -156,28 +157,30 @@ func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage i
shareAuth, is := a.(*LinkSharing)
if is {
list := &List{ID: shareAuth.ListID}
err := list.GetSimpleByID()
list, err := GetListSimpleByID(s, shareAuth.ListID)
if err != nil {
return nil, 0, 0, err
}
return getTasksForLists([]*List{list}, a, taskopts)
return getTasksForLists(s, []*List{list}, a, taskopts)
}
// If the list ID is not set, we get all tasks for the user.
// This allows to use this function in Task.ReadAll with a possibility to deprecate the latter at some point.
if tf.ListID == 0 {
tf.Lists, _, _, err = getRawListsForUser(&listOptions{
user: &user.User{ID: a.GetID()},
page: -1,
})
tf.Lists, _, _, err = getRawListsForUser(
s,
&listOptions{
user: &user.User{ID: a.GetID()},
page: -1,
},
)
if err != nil {
return nil, 0, 0, err
}
} else {
// Check the list exists and the user has acess on it
list := &List{ID: tf.ListID}
canRead, _, err := list.CanRead(a)
canRead, _, err := list.CanRead(s, a)
if err != nil {
return nil, 0, 0, err
}
@ -187,5 +190,5 @@ func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage i
tf.Lists = []*List{{ID: tf.ListID}}
}
return getTasksForLists(tf.Lists, a, taskopts)
return getTasksForLists(s, tf.Lists, a, taskopts)
}

View File

@ -986,6 +986,8 @@ func TestTaskCollection_ReadAll(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
lt := &TaskCollection{
ListID: tt.fields.ListID,
@ -1000,7 +1002,7 @@ func TestTaskCollection_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights,
}
got, _, _, err := lt.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50)
got, _, _, err := lt.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
if (err != nil) != tt.wantErr {
t.Errorf("Test %s, Task.ReadAll() error = %v, wantErr %v", tt.name, err, tt.wantErr)
return

View File

@ -17,28 +17,31 @@
package models
import "code.vikunja.io/web"
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead checks if a user can read a comment
func (tc *TaskComment) CanRead(a web.Auth) (bool, int, error) {
func (tc *TaskComment) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
t := Task{ID: tc.TaskID}
return t.CanRead(a)
return t.CanRead(s, a)
}
// CanDelete checks if a user can delete a comment
func (tc *TaskComment) CanDelete(a web.Auth) (bool, error) {
func (tc *TaskComment) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
t := Task{ID: tc.TaskID}
return t.CanWrite(a)
return t.CanWrite(s, a)
}
// CanUpdate checks if a user can update a comment
func (tc *TaskComment) CanUpdate(a web.Auth) (bool, error) {
func (tc *TaskComment) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
t := Task{ID: tc.TaskID}
return t.CanWrite(a)
return t.CanWrite(s, a)
}
// CanCreate checks if a user can create a new comment
func (tc *TaskComment) CanCreate(a web.Auth) (bool, error) {
func (tc *TaskComment) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
t := Task{ID: tc.TaskID}
return t.CanWrite(a)
return t.CanWrite(s, a)
}

View File

@ -20,6 +20,8 @@ package models
import (
"time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
)
@ -57,19 +59,19 @@ func (tc *TaskComment) TableName() string {
// @Failure 400 {object} web.HTTPError "Invalid task comment object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments [put]
func (tc *TaskComment) Create(a web.Auth) (err error) {
func (tc *TaskComment) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the task exists
_, err = GetTaskSimple(&Task{ID: tc.TaskID})
_, err = GetTaskSimple(s, &Task{ID: tc.TaskID})
if err != nil {
return err
}
tc.AuthorID = a.GetID()
_, err = x.Insert(tc)
_, err = s.Insert(tc)
if err != nil {
return
}
tc.Author, err = user.GetUserByID(a.GetID())
tc.Author, err = user.GetUserByID(s, a.GetID())
return
}
@ -87,8 +89,11 @@ func (tc *TaskComment) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "The task comment was not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments/{commentID} [delete]
func (tc *TaskComment) Delete() error {
deleted, err := x.ID(tc.ID).NoAutoCondition().Delete(tc)
func (tc *TaskComment) Delete(s *xorm.Session) error {
deleted, err := s.
ID(tc.ID).
NoAutoCondition().
Delete(tc)
if deleted == 0 {
return ErrTaskCommentDoesNotExist{ID: tc.ID}
}
@ -109,8 +114,11 @@ func (tc *TaskComment) Delete() error {
// @Failure 404 {object} web.HTTPError "The task comment was not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments/{commentID} [post]
func (tc *TaskComment) Update() error {
updated, err := x.ID(tc.ID).Cols("comment").Update(tc)
func (tc *TaskComment) Update(s *xorm.Session) error {
updated, err := s.
ID(tc.ID).
Cols("comment").
Update(tc)
if updated == 0 {
return ErrTaskCommentDoesNotExist{ID: tc.ID}
}
@ -131,8 +139,8 @@ func (tc *TaskComment) Update() error {
// @Failure 404 {object} web.HTTPError "The task comment was not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments/{commentID} [get]
func (tc *TaskComment) ReadOne() (err error) {
exists, err := x.Get(tc)
func (tc *TaskComment) ReadOne(s *xorm.Session) (err error) {
exists, err := s.Get(tc)
if err != nil {
return
}
@ -145,7 +153,7 @@ func (tc *TaskComment) ReadOne() (err error) {
// Get the author
author := &user.User{}
_, err = x.
_, err = s.
Where("id = ?", tc.AuthorID).
Get(author)
tc.Author = author
@ -163,10 +171,10 @@ func (tc *TaskComment) ReadOne() (err error) {
// @Success 200 {array} models.TaskComment "The array with all task comments"
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments [get]
func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (tc *TaskComment) ReadAll(s *xorm.Session, auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has access to the task
canRead, _, err := tc.CanRead(auth)
canRead, _, err := tc.CanRead(s, auth)
if err != nil {
return nil, 0, 0, err
}
@ -184,7 +192,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i
limit, start := getLimitFromPageIndex(page, perPage)
comments := []*TaskComment{}
query := x.
query := s.
Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%").
Join("LEFT", "users", "users.id = task_comments.author_id")
if limit > 0 {
@ -197,7 +205,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i
// Get all authors
authors := make(map[int64]*user.User)
err = x.
err = s.
Select("users.*").
Table("task_comments").
Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%").
@ -211,7 +219,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i
comment.Author = authors[comment.AuthorID]
}
numberOfTotalItems, err = x.
numberOfTotalItems, err = s.
Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%").
Count(&TaskCommentWithAuthor{})
return comments, len(comments), numberOfTotalItems, err

View File

@ -28,14 +28,20 @@ func TestTaskComment_Create(t *testing.T) {
u := &user.User{ID: 1}
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{
Comment: "test",
TaskID: 1,
}
err := tc.Create(u)
err := tc.Create(s, u)
assert.NoError(t, err)
assert.Equal(t, "test", tc.Comment)
assert.Equal(t, int64(1), tc.Author.ID)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "task_comments", map[string]interface{}{
"id": tc.ID,
"author_id": u.ID,
@ -45,11 +51,14 @@ func TestTaskComment_Create(t *testing.T) {
})
t.Run("nonexisting task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{
Comment: "test",
TaskID: 99999,
}
err := tc.Create(u)
err := tc.Create(s, u)
assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err))
})
@ -58,17 +67,26 @@ func TestTaskComment_Create(t *testing.T) {
func TestTaskComment_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 1}
err := tc.Delete()
err := tc.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "task_comments", map[string]interface{}{
"id": 1,
})
})
t.Run("nonexisting comment", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 9999}
err := tc.Delete()
err := tc.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrTaskCommentDoesNotExist(err))
})
@ -77,12 +95,18 @@ func TestTaskComment_Delete(t *testing.T) {
func TestTaskComment_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{
ID: 1,
Comment: "testing",
}
err := tc.Update()
err := tc.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "task_comments", map[string]interface{}{
"id": 1,
"comment": "testing",
@ -90,10 +114,13 @@ func TestTaskComment_Update(t *testing.T) {
})
t.Run("nonexisting comment", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{
ID: 9999,
}
err := tc.Update()
err := tc.Update(s)
assert.Error(t, err)
assert.True(t, IsErrTaskCommentDoesNotExist(err))
})
@ -102,16 +129,22 @@ func TestTaskComment_Update(t *testing.T) {
func TestTaskComment_ReadOne(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 1}
err := tc.ReadOne()
err := tc.ReadOne(s)
assert.NoError(t, err)
assert.Equal(t, "Lorem Ipsum Dolor Sit Amet", tc.Comment)
assert.NotEmpty(t, tc.Author.ID)
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 9999}
err := tc.ReadOne()
err := tc.ReadOne(s)
assert.Error(t, err)
assert.True(t, IsErrTaskCommentDoesNotExist(err))
})
@ -120,9 +153,12 @@ func TestTaskComment_ReadOne(t *testing.T) {
func TestTaskComment_ReadAll(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{TaskID: 1}
u := &user.User{ID: 1}
result, resultCount, total, err := tc.ReadAll(u, "", 0, -1)
result, resultCount, total, err := tc.ReadAll(s, u, "", 0, -1)
resultComment := result.([]*TaskComment)
assert.NoError(t, err)
assert.Equal(t, 1, resultCount)
@ -133,9 +169,12 @@ func TestTaskComment_ReadAll(t *testing.T) {
})
t.Run("no access to task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{TaskID: 14}
u := &user.User{ID: 1}
_, _, _, err := tc.ReadAll(u, "", 0, -1)
_, _, _, err := tc.ReadAll(s, u, "", 0, -1)
assert.Error(t, err)
assert.True(t, IsErrGenericForbidden(err))
})

View File

@ -20,6 +20,8 @@ package models
import (
"time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
)
@ -117,7 +119,7 @@ type RelatedTaskMap map[RelationKind][]*Task
// @Failure 400 {object} web.HTTPError "Invalid task relation object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/relations [put]
func (rel *TaskRelation) Create(a web.Auth) error {
func (rel *TaskRelation) Create(s *xorm.Session, a web.Auth) error {
// Check if both tasks are the same
if rel.TaskID == rel.OtherTaskID {
@ -128,7 +130,7 @@ func (rel *TaskRelation) Create(a web.Auth) error {
}
// Check if the relation already exists, in one form or the other.
exists, err := x.
exists, err := s.
Where("(task_id = ? AND other_task_id = ? AND relation_kind = ?) OR (task_id = ? AND other_task_id = ? AND relation_kind = ?)",
rel.TaskID, rel.OtherTaskID, rel.RelationKind, rel.TaskID, rel.OtherTaskID, rel.RelationKind).
Exist(rel)
@ -180,7 +182,7 @@ func (rel *TaskRelation) Create(a web.Auth) error {
}
// Finally insert everything
_, err = x.Insert(&[]*TaskRelation{
_, err = s.Insert(&[]*TaskRelation{
rel,
otherRelation,
})
@ -200,9 +202,9 @@ func (rel *TaskRelation) Create(a web.Auth) error {
// @Failure 404 {object} web.HTTPError "The task relation was not found."
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/relations [delete]
func (rel *TaskRelation) Delete() error {
func (rel *TaskRelation) Delete(s *xorm.Session) error {
// Check if the relation exists
exists, err := x.
exists, err := s.
Cols("task_id", "other_task_id", "relation_kind").
Get(rel)
if err != nil {
@ -216,6 +218,6 @@ func (rel *TaskRelation) Delete() error {
}
}
_, err = x.Delete(rel)
_, err = s.Delete(rel)
return err
}

View File

@ -17,17 +17,20 @@
package models
import "code.vikunja.io/web"
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanDelete checks if a user can delete a task relation
func (rel *TaskRelation) CanDelete(a web.Auth) (bool, error) {
func (rel *TaskRelation) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
// A user can delete a relation if it can update the base task
baseTask := &Task{ID: rel.TaskID}
return baseTask.CanUpdate(a)
return baseTask.CanUpdate(s, a)
}
// CanCreate checks if a user can create a new relation between two relations
func (rel *TaskRelation) CanCreate(a web.Auth) (bool, error) {
func (rel *TaskRelation) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
// Check if the relation kind is valid
if !rel.RelationKind.isValid() {
return false, ErrInvalidRelationKind{Kind: rel.RelationKind}
@ -35,14 +38,14 @@ func (rel *TaskRelation) CanCreate(a web.Auth) (bool, error) {
// Needs have write access to the base task and at least read access to the other task
baseTask := &Task{ID: rel.TaskID}
has, err := baseTask.CanUpdate(a)
has, err := baseTask.CanUpdate(s, a)
if err != nil || !has {
return false, err
}
// We explicitly don't check if the two tasks are on the same list.
otherTask := &Task{ID: rel.OtherTaskID}
has, _, err = otherTask.CanRead(a)
has, _, err = otherTask.CanRead(s, a)
if err != nil {
return false, err
}

View File

@ -28,13 +28,17 @@ import (
func TestTaskRelation_Create(t *testing.T) {
t.Run("Normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 2,
RelationKind: RelationKindSubtask,
}
err := rel.Create(&user.User{ID: 1})
err := rel.Create(s, &user.User{ID: 1})
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "task_relations", map[string]interface{}{
"task_id": 1,
@ -45,13 +49,17 @@ func TestTaskRelation_Create(t *testing.T) {
})
t.Run("Two Tasks In Different Lists", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 13,
RelationKind: RelationKindSubtask,
}
err := rel.Create(&user.User{ID: 1})
err := rel.Create(s, &user.User{ID: 1})
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "task_relations", map[string]interface{}{
"task_id": 1,
@ -62,24 +70,28 @@ func TestTaskRelation_Create(t *testing.T) {
})
t.Run("Already Existing", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 29,
RelationKind: RelationKindSubtask,
}
err := rel.Create(&user.User{ID: 1})
err := rel.Create(s, &user.User{ID: 1})
assert.Error(t, err)
assert.True(t, IsErrRelationAlreadyExists(err))
})
t.Run("Same Task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 1,
}
err := rel.Create(&user.User{ID: 1})
err := rel.Create(s, &user.User{ID: 1})
assert.Error(t, err)
assert.True(t, IsErrRelationTasksCannotBeTheSame(err))
})
@ -88,13 +100,17 @@ func TestTaskRelation_Create(t *testing.T) {
func TestTaskRelation_Delete(t *testing.T) {
t.Run("Normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 29,
RelationKind: RelationKindSubtask,
}
err := rel.Delete()
err := rel.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "task_relations", map[string]interface{}{
"task_id": 1,
@ -104,13 +120,15 @@ func TestTaskRelation_Delete(t *testing.T) {
})
t.Run("Not existing", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 9999,
OtherTaskID: 3,
RelationKind: RelationKindSubtask,
}
err := rel.Delete()
err := rel.Delete(s)
assert.Error(t, err)
assert.True(t, IsErrRelationDoesNotExist(err))
})
@ -119,86 +137,100 @@ func TestTaskRelation_Delete(t *testing.T) {
func TestTaskRelation_CanCreate(t *testing.T) {
t.Run("Normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 2,
RelationKind: RelationKindSubtask,
}
can, err := rel.CanCreate(&user.User{ID: 1})
can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("Two tasks on different lists", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 13,
RelationKind: RelationKindSubtask,
}
can, err := rel.CanCreate(&user.User{ID: 1})
can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err)
assert.True(t, can)
})
t.Run("No update rights on base task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 14,
OtherTaskID: 1,
RelationKind: RelationKindSubtask,
}
can, err := rel.CanCreate(&user.User{ID: 1})
can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("No update rights on base task, but read rights", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 15,
OtherTaskID: 1,
RelationKind: RelationKindSubtask,
}
can, err := rel.CanCreate(&user.User{ID: 1})
can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("No read rights on other task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 14,
RelationKind: RelationKindSubtask,
}
can, err := rel.CanCreate(&user.User{ID: 1})
can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err)
assert.False(t, can)
})
t.Run("Nonexisting base task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 999999,
OtherTaskID: 1,
RelationKind: RelationKindSubtask,
}
can, err := rel.CanCreate(&user.User{ID: 1})
can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err))
assert.False(t, can)
})
t.Run("Nonexisting other task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{
TaskID: 1,
OtherTaskID: 999999,
RelationKind: RelationKindSubtask,
}
can, err := rel.CanCreate(&user.User{ID: 1})
can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err))
assert.False(t, can)

View File

@ -19,6 +19,9 @@ package models
import (
"time"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/cron"
"code.vikunja.io/api/pkg/log"
@ -44,10 +47,10 @@ type taskUser struct {
User *user.User `xorm:"extends"`
}
func getTaskUsersForTasks(taskIDs []int64) (taskUsers []*taskUser, err error) {
func getTaskUsersForTasks(s *xorm.Session, taskIDs []int64) (taskUsers []*taskUser, err error) {
// Get all creators of tasks
creators := make(map[int64]*user.User, len(taskIDs))
err = x.
err = s.
Select("users.id, users.username, users.email, users.name").
Join("LEFT", "tasks", "tasks.created_by_id = users.id").
In("tasks.id", taskIDs).
@ -58,13 +61,13 @@ func getTaskUsersForTasks(taskIDs []int64) (taskUsers []*taskUser, err error) {
return
}
assignees, err := getRawTaskAssigneesForTasks(taskIDs)
assignees, err := getRawTaskAssigneesForTasks(s, taskIDs)
if err != nil {
return
}
taskMap := make(map[int64]*Task, len(taskIDs))
err = x.In("id", taskIDs).Find(&taskMap)
err = s.In("id", taskIDs).Find(&taskMap)
if err != nil {
return
}
@ -106,6 +109,8 @@ func RegisterReminderCron() {
log.Debugf("[Task Reminder Cron] Timezone is %s", tz)
s := db.NewSession()
err := cron.Schedule("* * * * *", func() {
// By default, time.Now() includes nanoseconds which we don't save. That results in getting the wrong dates,
// so we make sure the time we use to get the reminders don't contain nanoseconds.
@ -116,7 +121,7 @@ func RegisterReminderCron() {
log.Debugf("[Task Reminder Cron] Looking for reminders between %s and %s to send...", now, nextMinute)
reminders := []*TaskReminder{}
err := x.
err := s.
Where("reminder >= ? and reminder < ?", now.Format(dbFormat), nextMinute.Format(dbFormat)).
Find(&reminders)
if err != nil {
@ -136,7 +141,7 @@ func RegisterReminderCron() {
taskIDs = append(taskIDs, r.TaskID)
}
users, err := getTaskUsersForTasks(taskIDs)
users, err := getTaskUsersForTasks(s, taskIDs)
if err != nil {
log.Errorf("[Task Reminder Cron] Could not get task users to send them reminders: %s", err)
return

View File

@ -22,6 +22,8 @@ import (
"strconv"
"time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/metrics"
"code.vikunja.io/api/pkg/user"
@ -153,7 +155,7 @@ type taskOptions struct {
// @Success 200 {array} models.Task "The tasks"
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/all [get]
func (t *Task) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
func (t *Task) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
return nil, 0, 0, nil
}
@ -209,7 +211,7 @@ func getFilterCondForSeparateTable(table string, concat taskFilterConcatinator,
}
//nolint:gocyclo
func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) {
func getRawTasksForLists(s *xorm.Session, lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) {
// If the user does not have any lists, don't try to get any tasks
if len(lists) == 0 {
@ -253,7 +255,7 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
// Postgres sorts by default entries with null values after ones with values.
// To make that consistent with the sort order we have and other dbms, we're adding a separate clause here.
if x.Dialect().URI().DBType == schemas.POSTGRES {
if db.Type() == schemas.POSTGRES {
if param.orderBy == orderAscending {
orderby += " NULLS FIRST"
}
@ -324,9 +326,7 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
}
// Then return all tasks for that lists
query := x.NewSession().
OrderBy(orderby)
queryCount := x.NewSession()
var where builder.Cond
if len(opts.search) > 0 {
// Postgres' is case sensitive by default.
@ -335,11 +335,9 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
// See https://stackoverflow.com/q/7005302/10924593
// Seems okay to use that now, we may need to find a better solution overall in the future.
if config.DatabaseType.GetString() == "postgres" {
query = query.Where("title ILIKE ?", "%"+opts.search+"%")
queryCount = queryCount.Where("title ILIKE ?", "%"+opts.search+"%")
where = builder.Expr("title ILIKE ?", "%"+opts.search+"%")
} else {
query = query.Where("title LIKE ?", "%"+opts.search+"%")
queryCount = queryCount.Where("title LIKE ?", "%"+opts.search+"%")
where = &builder.Like{"title", "%" + opts.search + "%"}
}
}
@ -352,10 +350,13 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
if hasFavoriteLists {
// Make sure users can only see their favorites
userLists, _, _, err := getRawListsForUser(&listOptions{
user: &user.User{ID: a.GetID()},
page: -1,
})
userLists, _, _, err := getRawListsForUser(
s,
&listOptions{
user: &user.User{ID: a.GetID()},
page: -1,
},
)
if err != nil {
return nil, 0, 0, err
}
@ -399,32 +400,31 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
filters = append(filters, cond)
}
query = query.Where(listCond)
queryCount = queryCount.Where(listCond)
var filterCond builder.Cond
if len(filters) > 0 {
if opts.filterConcat == filterConcatOr {
query = query.Where(builder.Or(filters...))
queryCount = queryCount.Where(builder.Or(filters...))
filterCond = builder.Or(filters...)
}
if opts.filterConcat == filterConcatAnd {
query = query.Where(builder.And(filters...))
queryCount = queryCount.Where(builder.And(filters...))
filterCond = builder.And(filters...)
}
}
limit, start := getLimitFromPageIndex(opts.page, opts.perPage)
cond := builder.And(listCond, where, filterCond)
query := s.Where(cond)
if limit > 0 {
query = query.Limit(limit, start)
}
tasks = []*Task{}
err = query.Find(&tasks)
err = query.OrderBy(orderby).Find(&tasks)
if err != nil {
return nil, 0, 0, err
}
queryCount := s.Where(cond)
totalItems, err = queryCount.
Count(&Task{})
if err != nil {
@ -434,9 +434,9 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
return tasks, len(tasks), totalItems, nil
}
func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) {
func getTasksForLists(s *xorm.Session, lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) {
tasks, resultCount, totalItems, err = getRawTasksForLists(lists, a, opts)
tasks, resultCount, totalItems, err = getRawTasksForLists(s, lists, a, opts)
if err != nil {
return nil, 0, 0, err
}
@ -446,7 +446,7 @@ func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Ta
taskMap[t.ID] = t
}
err = addMoreInfoToTasks(taskMap)
err = addMoreInfoToTasks(s, taskMap)
if err != nil {
return nil, 0, 0, err
}
@ -455,18 +455,18 @@ func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Ta
}
// GetTaskByIDSimple returns a raw task without extra data by the task ID
func GetTaskByIDSimple(taskID int64) (task Task, err error) {
func GetTaskByIDSimple(s *xorm.Session, taskID int64) (task Task, err error) {
if taskID < 1 {
return Task{}, ErrTaskDoesNotExist{taskID}
}
return GetTaskSimple(&Task{ID: taskID})
return GetTaskSimple(s, &Task{ID: taskID})
}
// GetTaskSimple returns a raw task without extra data
func GetTaskSimple(t *Task) (task Task, err error) {
func GetTaskSimple(s *xorm.Session, t *Task) (task Task, err error) {
task = *t
exists, err := x.Get(&task)
exists, err := s.Get(&task)
if err != nil {
return Task{}, err
}
@ -478,14 +478,14 @@ func GetTaskSimple(t *Task) (task Task, err error) {
}
// GetTasksByIDs returns all tasks for a list of ids
func (bt *BulkTask) GetTasksByIDs() (err error) {
func (bt *BulkTask) GetTasksByIDs(s *xorm.Session) (err error) {
for _, id := range bt.IDs {
if id < 1 {
return ErrTaskDoesNotExist{id}
}
}
err = x.In("id", bt.IDs).Find(&bt.Tasks)
err = s.In("id", bt.IDs).Find(&bt.Tasks)
if err != nil {
return
}
@ -494,9 +494,9 @@ func (bt *BulkTask) GetTasksByIDs() (err error) {
}
// GetTasksByUIDs gets all tasks from a bunch of uids
func GetTasksByUIDs(uids []string) (tasks []*Task, err error) {
func GetTasksByUIDs(s *xorm.Session, uids []string) (tasks []*Task, err error) {
tasks = []*Task{}
err = x.In("uid", uids).Find(&tasks)
err = s.In("uid", uids).Find(&tasks)
if err != nil {
return
}
@ -506,13 +506,13 @@ func GetTasksByUIDs(uids []string) (tasks []*Task, err error) {
taskMap[t.ID] = t
}
err = addMoreInfoToTasks(taskMap)
err = addMoreInfoToTasks(s, taskMap)
return
}
func getRemindersForTasks(taskIDs []int64) (reminders []*TaskReminder, err error) {
func getRemindersForTasks(s *xorm.Session, taskIDs []int64) (reminders []*TaskReminder, err error) {
reminders = []*TaskReminder{}
err = x.In("task_id", taskIDs).Find(&reminders)
err = s.In("task_id", taskIDs).Find(&reminders)
return
}
@ -521,8 +521,8 @@ func (t *Task) setIdentifier(list *List) {
}
// Get all assignees
func addAssigneesToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
taskAssignees, err := getRawTaskAssigneesForTasks(taskIDs)
func addAssigneesToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
taskAssignees, err := getRawTaskAssigneesForTasks(s, taskIDs)
if err != nil {
return
}
@ -538,8 +538,8 @@ func addAssigneesToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
}
// Get all labels for all the tasks
func addLabelsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
labels, _, _, err := getLabelsByTaskIDs(&LabelByTaskIDsOptions{
func addLabelsToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
labels, _, _, err := getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
TaskIDs: taskIDs,
Page: -1,
})
@ -556,8 +556,8 @@ func addLabelsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
}
// Get task attachments
func addAttachmentsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
attachments, err := getTaskAttachmentsByTaskIDs(taskIDs)
func addAttachmentsToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
attachments, err := getTaskAttachmentsByTaskIDs(s, taskIDs)
if err != nil {
return
}
@ -568,11 +568,11 @@ func addAttachmentsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error)
return
}
func getTaskReminderMap(taskIDs []int64) (taskReminders map[int64][]time.Time, err error) {
func getTaskReminderMap(s *xorm.Session, taskIDs []int64) (taskReminders map[int64][]time.Time, err error) {
taskReminders = make(map[int64][]time.Time)
// Get all reminders and put them in a map to have it easier later
reminders, err := getRemindersForTasks(taskIDs)
reminders, err := getRemindersForTasks(s, taskIDs)
if err != nil {
return
}
@ -584,9 +584,9 @@ func getTaskReminderMap(taskIDs []int64) (taskReminders map[int64][]time.Time, e
return
}
func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
func addRelatedTasksToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
relatedTasks := []*TaskRelation{}
err = x.In("task_id", taskIDs).Find(&relatedTasks)
err = s.In("task_id", taskIDs).Find(&relatedTasks)
if err != nil {
return
}
@ -597,7 +597,7 @@ func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error
relatedTaskIDs = append(relatedTaskIDs, rt.OtherTaskID)
}
fullRelatedTasks := make(map[int64]*Task)
err = x.In("id", relatedTaskIDs).Find(&fullRelatedTasks)
err = s.In("id", relatedTaskIDs).Find(&fullRelatedTasks)
if err != nil {
return
}
@ -614,7 +614,7 @@ func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error
// This function takes a map with pointers and returns a slice with pointers to tasks
// It adds more stuff like assignees/labels/etc to a bunch of tasks
func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) {
func addMoreInfoToTasks(s *xorm.Session, taskMap map[int64]*Task) (err error) {
// No need to iterate over users and stuff if the list doesn't have tasks
if len(taskMap) == 0 {
@ -631,33 +631,33 @@ func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) {
listIDs = append(listIDs, i.ListID)
}
err = addAssigneesToTasks(taskIDs, taskMap)
err = addAssigneesToTasks(s, taskIDs, taskMap)
if err != nil {
return
}
err = addLabelsToTasks(taskIDs, taskMap)
err = addLabelsToTasks(s, taskIDs, taskMap)
if err != nil {
return
}
err = addAttachmentsToTasks(taskIDs, taskMap)
err = addAttachmentsToTasks(s, taskIDs, taskMap)
if err != nil {
return
}
users, err := user.GetUsersByIDs(userIDs)
users, err := user.GetUsersByIDs(s, userIDs)
if err != nil {
return
}
taskReminders, err := getTaskReminderMap(taskIDs)
taskReminders, err := getTaskReminderMap(s, taskIDs)
if err != nil {
return err
}
// Get all identifiers
lists, err := GetListsByIDs(listIDs)
lists, err := GetListsByIDs(s, listIDs)
if err != nil {
return err
}
@ -679,7 +679,7 @@ func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) {
}
// Get all related tasks
err = addRelatedTasksToTasks(taskIDs, taskMap)
err = addRelatedTasksToTasks(s, taskIDs, taskMap)
return
}
@ -739,14 +739,8 @@ func checkBucketLimit(s *xorm.Session, t *Task, bucket *Bucket) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [put]
func (t *Task) Create(a web.Auth) (err error) {
s := x.NewSession()
err = createTask(s, t, a, true)
if err != nil {
_ = s.Rollback()
return err
}
return s.Commit()
func (t *Task) Create(s *xorm.Session, a web.Auth) (err error) {
return createTask(s, t, a, true)
}
func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err error) {
@ -759,16 +753,16 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err
}
// Check if the list exists
l := &List{ID: t.ListID}
if err = l.getSimpleByID(s); err != nil {
return
l, err := GetListSimpleByID(s, t.ListID)
if err != nil {
return err
}
if _, is := a.(*LinkSharing); is {
// A negative user id indicates user share links
t.CreatedByID = a.GetID() * -1
} else {
u, err := user.GetUserByID(a.GetID())
u, err := user.GetUserByID(s, a.GetID())
if err != nil {
return err
}
@ -834,7 +828,7 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err
t.setIdentifier(l)
err = updateListLastUpdatedS(s, &List{ID: t.ListID})
err = updateListLastUpdated(s, &List{ID: t.ListID})
return
}
@ -853,21 +847,17 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id} [post]
//nolint:gocyclo
func (t *Task) Update() (err error) {
s := x.NewSession()
func (t *Task) Update(s *xorm.Session) (err error) {
// Check if the task exists and get the old values
ot, err := GetTaskByIDSimple(t.ID)
ot, err := GetTaskByIDSimple(s, t.ID)
if err != nil {
_ = s.Rollback()
return
}
// Get the reminders
reminders, err := getRemindersForTasks([]int64{t.ID})
reminders, err := getRemindersForTasks(s, []int64{t.ID})
if err != nil {
_ = s.Rollback()
return
}
@ -881,20 +871,17 @@ func (t *Task) Update() (err error) {
// Update the assignees
if err := ot.updateTaskAssignees(s, t.Assignees); err != nil {
_ = s.Rollback()
return err
}
// Update the reminders
if err := ot.updateReminders(s, t.Reminders); err != nil {
_ = s.Rollback()
return err
}
// If there is a bucket set, make sure they belong to the same list as the task
err = checkBucketAndTaskBelongToSameList(s, &ot, t.BucketID)
if err != nil {
_ = s.Rollback()
return
}
@ -923,7 +910,6 @@ func (t *Task) Update() (err error) {
if t.BucketID == 0 || (t.ListID != 0 && ot.ListID != t.ListID) {
bucket, err = getDefaultBucket(s, t.ListID)
if err != nil {
_ = s.Rollback()
return err
}
t.BucketID = bucket.ID
@ -934,7 +920,6 @@ func (t *Task) Update() (err error) {
latestTask := &Task{}
_, err = s.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask)
if err != nil {
_ = s.Rollback()
return err
}
@ -946,7 +931,6 @@ func (t *Task) Update() (err error) {
// Only check the bucket limit if the task is being moved between buckets, allow reordering the task within a bucket
if t.BucketID != ot.BucketID {
if err := checkBucketLimit(s, t, bucket); err != nil {
_ = s.Rollback()
return err
}
}
@ -972,7 +956,6 @@ func (t *Task) Update() (err error) {
// Which is why we merge the actual task struct with the one we got from the db
// The user struct overrides values in the actual one.
if err := mergo.Merge(&ot, t, mergo.WithOverride); err != nil {
_ = s.Rollback()
return err
}
@ -1034,7 +1017,6 @@ func (t *Task) Update() (err error) {
Update(ot)
*t = ot
if err != nil {
_ = s.Rollback()
return err
}
// Get the task updated timestamp in a new struct - if we'd just try to put it into t which we already have, it
@ -1042,17 +1024,11 @@ func (t *Task) Update() (err error) {
nt := &Task{}
_, err = s.ID(t.ID).Get(nt)
if err != nil {
_ = s.Rollback()
return err
}
t.Updated = nt.Updated
err = updateListLastUpdatedS(s, &List{ID: t.ListID})
if err != nil {
_ = s.Rollback()
return err
}
return s.Commit()
return updateListLastUpdated(s, &List{ID: t.ListID})
}
// This helper function updates the reminders, doneAt, start and end dates of the *old* task
@ -1174,7 +1150,7 @@ func (t *Task) updateReminders(s *xorm.Session, reminders []time.Time) (err erro
t.Reminders = nil
}
err = updateListLastUpdatedS(s, &List{ID: t.ListID})
err = updateListLastUpdated(s, &List{ID: t.ListID})
return
}
@ -1190,20 +1166,20 @@ func (t *Task) updateReminders(s *xorm.Session, reminders []time.Time) (err erro
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id} [delete]
func (t *Task) Delete() (err error) {
func (t *Task) Delete(s *xorm.Session) (err error) {
if _, err = x.ID(t.ID).Delete(Task{}); err != nil {
if _, err = s.ID(t.ID).Delete(Task{}); err != nil {
return err
}
// Delete assignees
if _, err = x.Where("task_id = ?", t.ID).Delete(TaskAssginee{}); err != nil {
if _, err = s.Where("task_id = ?", t.ID).Delete(TaskAssginee{}); err != nil {
return err
}
metrics.UpdateCount(-1, metrics.TaskCountKey)
err = updateListLastUpdated(&List{ID: t.ListID})
err = updateListLastUpdated(s, &List{ID: t.ListID})
return
}
@ -1219,16 +1195,16 @@ func (t *Task) Delete() (err error) {
// @Failure 404 {object} models.Message "Task not found"
// @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{ID} [get]
func (t *Task) ReadOne() (err error) {
func (t *Task) ReadOne(s *xorm.Session) (err error) {
taskMap := make(map[int64]*Task, 1)
taskMap[t.ID] = &Task{}
*taskMap[t.ID], err = GetTaskByIDSimple(t.ID)
*taskMap[t.ID], err = GetTaskByIDSimple(s, t.ID)
if err != nil {
return
}
err = addMoreInfoToTasks(taskMap)
err = addMoreInfoToTasks(s, taskMap)
if err != nil {
return
}

View File

@ -18,47 +18,48 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanDelete checks if the user can delete an task
func (t *Task) CanDelete(a web.Auth) (bool, error) {
return t.canDoTask(a)
func (t *Task) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return t.canDoTask(s, a)
}
// CanUpdate determines if a user has the right to update a list task
func (t *Task) CanUpdate(a web.Auth) (bool, error) {
return t.canDoTask(a)
func (t *Task) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return t.canDoTask(s, a)
}
// CanCreate determines if a user has the right to create a list task
func (t *Task) CanCreate(a web.Auth) (bool, error) {
func (t *Task) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
// A user can do a task if he has write acces to its list
l := &List{ID: t.ListID}
return l.CanWrite(a)
return l.CanWrite(s, a)
}
// CanRead determines if a user can read a task
func (t *Task) CanRead(a web.Auth) (canRead bool, maxRight int, err error) {
func (t *Task) CanRead(s *xorm.Session, a web.Auth) (canRead bool, maxRight int, err error) {
// Get the task, error out if it doesn't exist
*t, err = GetTaskByIDSimple(t.ID)
*t, err = GetTaskByIDSimple(s, t.ID)
if err != nil {
return
}
// A user can read a task if it has access to the list
l := &List{ID: t.ListID}
return l.CanRead(a)
return l.CanRead(s, a)
}
// CanWrite checks if a user has write access to a task
func (t *Task) CanWrite(a web.Auth) (canWrite bool, err error) {
return t.canDoTask(a)
func (t *Task) CanWrite(s *xorm.Session, a web.Auth) (canWrite bool, err error) {
return t.canDoTask(s, a)
}
// Helper function to check if a user can do stuff on a list task
func (t *Task) canDoTask(a web.Auth) (bool, error) {
func (t *Task) canDoTask(s *xorm.Session, a web.Auth) (bool, error) {
// Get the task
ot, err := GetTaskByIDSimple(t.ID)
ot, err := GetTaskByIDSimple(s, t.ID)
if err != nil {
return false, err
}
@ -66,7 +67,7 @@ func (t *Task) canDoTask(a web.Auth) (bool, error) {
// Check if we're moving the task into a different list to check if the user has sufficient rights for that on the new list
if t.ListID != 0 && t.ListID != ot.ListID {
newList := &List{ID: t.ListID}
can, err := newList.CanWrite(a)
can, err := newList.CanWrite(s, a)
if err != nil {
return false, err
}
@ -77,5 +78,5 @@ func (t *Task) canDoTask(a web.Auth) (bool, error) {
// A user can do a task if it has write acces to its list
l := &List{ID: ot.ListID}
return l.CanWrite(a)
return l.CanWrite(s, a)
}

View File

@ -36,12 +36,15 @@ func TestTask_Create(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
Title: "Lorem",
Description: "Lorem Ipsum Dolor",
ListID: 1,
}
err := task.Create(usr)
err := task.Create(s, usr)
assert.NoError(t, err)
// Assert getting a uid
assert.NotEmpty(t, task.UID)
@ -50,6 +53,9 @@ func TestTask_Create(t *testing.T) {
assert.Equal(t, int64(18), task.Index)
// Assert moving it into the default bucket
assert.Equal(t, int64(1), task.BucketID)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "tasks", map[string]interface{}{
"id": task.ID,
"title": "Lorem",
@ -62,47 +68,59 @@ func TestTask_Create(t *testing.T) {
})
t.Run("empty title", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
Title: "",
Description: "Lorem Ipsum Dolor",
ListID: 1,
}
err := task.Create(usr)
err := task.Create(s, usr)
assert.Error(t, err)
assert.True(t, IsErrTaskCannotBeEmpty(err))
})
t.Run("nonexistant list", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
Title: "Test",
Description: "Lorem Ipsum Dolor",
ListID: 9999999,
}
err := task.Create(usr)
err := task.Create(s, usr)
assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err))
})
t.Run("noneixtant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
nUser := &user.User{ID: 99999999}
task := &Task{
Title: "Test",
Description: "Lorem Ipsum Dolor",
ListID: 1,
}
err := task.Create(nUser)
err := task.Create(s, nUser)
assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err))
})
t.Run("full bucket", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
Title: "Lorem",
Description: "Lorem Ipsum Dolor",
ListID: 1,
BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3
}
err := task.Create(usr)
err := task.Create(s, usr)
assert.Error(t, err)
assert.True(t, IsErrBucketLimitExceeded(err))
})
@ -111,14 +129,20 @@ func TestTask_Create(t *testing.T) {
func TestTask_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
ID: 1,
Title: "test10000",
Description: "Lorem Ipsum Dolor",
ListID: 1,
}
err := task.Update()
err := task.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "tasks", map[string]interface{}{
"id": 1,
"title": "test10000",
@ -128,18 +152,24 @@ func TestTask_Update(t *testing.T) {
})
t.Run("nonexistant task", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
ID: 9999999,
Title: "test10000",
Description: "Lorem Ipsum Dolor",
ListID: 1,
}
err := task.Update()
err := task.Update(s)
assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err))
})
t.Run("full bucket", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
ID: 1,
Title: "test10000",
@ -147,12 +177,15 @@ func TestTask_Update(t *testing.T) {
ListID: 1,
BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3
}
err := task.Update()
err := task.Update(s)
assert.Error(t, err)
assert.True(t, IsErrBucketLimitExceeded(err))
})
t.Run("full bucket but not changing the bucket", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
ID: 4,
Title: "test10000",
@ -161,7 +194,7 @@ func TestTask_Update(t *testing.T) {
ListID: 1,
BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3
}
err := task.Update()
err := task.Update(s)
assert.NoError(t, err)
})
}
@ -169,11 +202,17 @@ func TestTask_Update(t *testing.T) {
func TestTask_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{
ID: 1,
}
err := task.Delete()
err := task.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "tasks", map[string]interface{}{
"id": 1,
})
@ -183,6 +222,9 @@ func TestTask_Delete(t *testing.T) {
func TestUpdateDone(t *testing.T) {
t.Run("marking a task as done", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
oldTask := &Task{Done: false}
newTask := &Task{Done: true}
updateDone(oldTask, newTask)
@ -190,6 +232,9 @@ func TestUpdateDone(t *testing.T) {
})
t.Run("unmarking a task as done", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
oldTask := &Task{Done: true}
newTask := &Task{Done: false}
updateDone(oldTask, newTask)
@ -397,15 +442,21 @@ func TestUpdateDone(t *testing.T) {
func TestTask_ReadOne(t *testing.T) {
t.Run("default", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ID: 1}
err := task.ReadOne()
err := task.ReadOne(s)
assert.NoError(t, err)
assert.Equal(t, "task #1", task.Title)
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ID: 99999}
err := task.ReadOne()
err := task.ReadOne(s)
assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err))
})

View File

@ -19,6 +19,7 @@ package models
import (
user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// Create implements the create method to assign a user to a team
@ -35,23 +36,24 @@ import (
// @Failure 403 {object} web.HTTPError "The user does not have access to the team"
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id}/members [put]
func (tm *TeamMember) Create(a web.Auth) (err error) {
func (tm *TeamMember) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the team extst
_, err = GetTeamByID(tm.TeamID)
_, err = GetTeamByID(s, tm.TeamID)
if err != nil {
return
}
// Check if the user exists
user, err := user2.GetUserByUsername(tm.Username)
user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil {
return
}
tm.UserID = user.ID
// Check if that user is already part of the team
exists, err := x.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).
exists, err := s.
Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).
Get(&TeamMember{})
if err != nil {
return
@ -61,7 +63,7 @@ func (tm *TeamMember) Create(a web.Auth) (err error) {
}
// Insert the user
_, err = x.Insert(tm)
_, err = s.Insert(tm)
return
}
@ -76,9 +78,9 @@ func (tm *TeamMember) Create(a web.Auth) (err error) {
// @Success 200 {object} models.Message "The user was successfully removed from the team."
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id}/members/{userID} [delete]
func (tm *TeamMember) Delete() (err error) {
func (tm *TeamMember) Delete(s *xorm.Session) (err error) {
total, err := x.Where("team_id = ?", tm.TeamID).Count(&TeamMember{})
total, err := s.Where("team_id = ?", tm.TeamID).Count(&TeamMember{})
if err != nil {
return
}
@ -87,13 +89,13 @@ func (tm *TeamMember) Delete() (err error) {
}
// Find the numeric user id
user, err := user2.GetUserByUsername(tm.Username)
user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil {
return
}
tm.UserID = user.ID
_, err = x.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).Delete(&TeamMember{})
_, err = s.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).Delete(&TeamMember{})
return
}
@ -108,9 +110,9 @@ func (tm *TeamMember) Delete() (err error) {
// @Success 200 {object} models.Message "The member right was successfully changed."
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id}/members/{userID}/admin [post]
func (tm *TeamMember) Update() (err error) {
func (tm *TeamMember) Update(s *xorm.Session) (err error) {
// Find the numeric user id
user, err := user2.GetUserByUsername(tm.Username)
user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil {
return
}
@ -118,7 +120,7 @@ func (tm *TeamMember) Update() (err error) {
// Get the full member object and change the admin right
ttm := &TeamMember{}
_, err = x.
_, err = s.
Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).
Get(ttm)
if err != nil {
@ -127,7 +129,7 @@ func (tm *TeamMember) Update() (err error) {
ttm.Admin = !ttm.Admin
// Do the update
_, err = x.
_, err = s.
Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).
Cols("admin").
Update(ttm)

View File

@ -18,32 +18,34 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if the user can add a new tem member
func (tm *TeamMember) CanCreate(a web.Auth) (bool, error) {
return tm.IsAdmin(a)
func (tm *TeamMember) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return tm.IsAdmin(s, a)
}
// CanDelete checks if the user can delete a new team member
func (tm *TeamMember) CanDelete(a web.Auth) (bool, error) {
return tm.IsAdmin(a)
func (tm *TeamMember) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return tm.IsAdmin(s, a)
}
// CanUpdate checks if the user can modify a team member's right
func (tm *TeamMember) CanUpdate(a web.Auth) (bool, error) {
return tm.IsAdmin(a)
func (tm *TeamMember) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return tm.IsAdmin(s, a)
}
// IsAdmin checks if the user is team admin
func (tm *TeamMember) IsAdmin(a web.Auth) (bool, error) {
func (tm *TeamMember) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
// Don't allow anything if we're dealing with a list share here
if _, is := a.(*LinkSharing); is {
return false, nil
}
// A user can add a member to a team if he is admin of that team
exists, err := x.Where("user_id = ? AND team_id = ? AND admin = ?", a.GetID(), tm.TeamID, true).
exists, err := s.
Where("user_id = ? AND team_id = ? AND admin = ?", a.GetID(), tm.TeamID, true).
Get(&TeamMember{})
return exists, err
}

View File

@ -32,12 +32,18 @@ func TestTeamMember_Create(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{
TeamID: 1,
Username: "user3",
}
err := tm.Create(doer)
err := tm.Create(s, doer)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_members", map[string]interface{}{
"id": tm.ID,
"team_id": 1,
@ -46,31 +52,40 @@ func TestTeamMember_Create(t *testing.T) {
})
t.Run("already existing", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{
TeamID: 1,
Username: "user1",
}
err := tm.Create(doer)
err := tm.Create(s, doer)
assert.Error(t, err)
assert.True(t, IsErrUserIsMemberOfTeam(err))
})
t.Run("nonexisting user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{
TeamID: 1,
Username: "nonexistinguser",
}
err := tm.Create(doer)
err := tm.Create(s, doer)
assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err))
})
t.Run("nonexisting team", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{
TeamID: 9999999,
Username: "user1",
}
err := tm.Create(doer)
err := tm.Create(s, doer)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
})
@ -79,12 +94,18 @@ func TestTeamMember_Create(t *testing.T) {
func TestTeamMember_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{
TeamID: 1,
Username: "user1",
}
err := tm.Delete()
err := tm.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "team_members", map[string]interface{}{
"team_id": 1,
"user_id": 1,
@ -95,14 +116,20 @@ func TestTeamMember_Delete(t *testing.T) {
func TestTeamMember_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{
TeamID: 1,
Username: "user1",
Admin: true,
}
err := tm.Update()
err := tm.Update(s)
assert.NoError(t, err)
assert.False(t, tm.Admin) // Since this endpoint toggles the right, we should get a false for admin back.
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_members", map[string]interface{}{
"team_id": 1,
"user_id": 1,
@ -113,14 +140,20 @@ func TestTeamMember_Update(t *testing.T) {
// should ignore what was passed.
t.Run("explicitly false in payload", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{
TeamID: 1,
Username: "user1",
Admin: true,
}
err := tm.Update()
err := tm.Update(s)
assert.NoError(t, err)
assert.False(t, tm.Admin)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_members", map[string]interface{}{
"team_id": 1,
"user_id": 1,

View File

@ -19,6 +19,8 @@ package models
import (
"time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/metrics"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web"
@ -54,10 +56,6 @@ func (Team) TableName() string {
return "teams"
}
// AfterLoad gets the created by user object
func (t *Team) AfterLoad() {
}
// TeamMember defines the relationship between a user and a team
type TeamMember struct {
// The unique, numeric id of this team member relation.
@ -92,14 +90,14 @@ type TeamUser struct {
}
// GetTeamByID gets a team by its ID
func GetTeamByID(id int64) (team *Team, err error) {
func GetTeamByID(s *xorm.Session, id int64) (team *Team, err error) {
if id < 1 {
return team, ErrTeamDoesNotExist{id}
}
t := Team{}
exists, err := x.
exists, err := s.
Where("id = ?", id).
Get(&t)
if err != nil {
@ -110,7 +108,7 @@ func GetTeamByID(id int64) (team *Team, err error) {
}
teamSlice := []*Team{&t}
err = addMoreInfoToTeams(teamSlice)
err = addMoreInfoToTeams(s, teamSlice)
if err != nil {
return
}
@ -120,7 +118,7 @@ func GetTeamByID(id int64) (team *Team, err error) {
return
}
func addMoreInfoToTeams(teams []*Team) (err error) {
func addMoreInfoToTeams(s *xorm.Session, teams []*Team) (err error) {
// Put the teams in a map to make assigning more info to it more efficient
teamMap := make(map[int64]*Team, len(teams))
var teamIDs []int64
@ -133,7 +131,8 @@ func addMoreInfoToTeams(teams []*Team) (err error) {
// Get all owners and team members
users := make(map[int64]*TeamUser)
err = x.Select("*").
err = s.
Select("*").
Table("users").
Join("LEFT", "team_members", "team_members.user_id = users.id").
Join("LEFT", "teams", "team_members.team_id = teams.id").
@ -178,8 +177,8 @@ func addMoreInfoToTeams(teams []*Team) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the team"
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id} [get]
func (t *Team) ReadOne() (err error) {
team, err := GetTeamByID(t.ID)
func (t *Team) ReadOne(s *xorm.Session) (err error) {
team, err := GetTeamByID(s, t.ID)
if team != nil {
*t = *team
}
@ -199,7 +198,7 @@ func (t *Team) ReadOne() (err error) {
// @Success 200 {array} models.Team "The teams."
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams [get]
func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
func (t *Team) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
if _, is := a.(*LinkSharing); is {
return nil, 0, 0, ErrGenericForbidden{}
}
@ -207,7 +206,7 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result
limit, start := getLimitFromPageIndex(page, perPage)
all := []*Team{}
query := x.Select("teams.*").
query := s.Select("teams.*").
Table("teams").
Join("INNER", "team_members", "team_members.team_id = teams.id").
Where("team_members.user_id = ?", a.GetID()).
@ -220,12 +219,12 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result
return nil, 0, 0, err
}
err = addMoreInfoToTeams(all)
err = addMoreInfoToTeams(s, all)
if err != nil {
return nil, 0, 0, err
}
numberOfTotalItems, err = x.
numberOfTotalItems, err = s.
Table("teams").
Join("INNER", "team_members", "team_members.team_id = teams.id").
Where("team_members.user_id = ?", a.GetID()).
@ -246,7 +245,7 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result
// @Failure 400 {object} web.HTTPError "Invalid team object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams [put]
func (t *Team) Create(a web.Auth) (err error) {
func (t *Team) Create(s *xorm.Session, a web.Auth) (err error) {
doer, err := user.GetFromAuth(a)
if err != nil {
return err
@ -260,14 +259,14 @@ func (t *Team) Create(a web.Auth) (err error) {
t.CreatedByID = doer.ID
t.CreatedBy = doer
_, err = x.Insert(t)
_, err = s.Insert(t)
if err != nil {
return
}
// Insert the current user as member and admin
tm := TeamMember{TeamID: t.ID, Username: doer.Username, Admin: true}
if err = tm.Create(doer); err != nil {
if err = tm.Create(s, doer); err != nil {
return err
}
@ -286,28 +285,28 @@ func (t *Team) Create(a web.Auth) (err error) {
// @Failure 400 {object} web.HTTPError "Invalid team object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id} [delete]
func (t *Team) Delete() (err error) {
func (t *Team) Delete(s *xorm.Session) (err error) {
// Delete the team
_, err = x.ID(t.ID).Delete(&Team{})
_, err = s.ID(t.ID).Delete(&Team{})
if err != nil {
return
}
// Delete team members
_, err = x.Where("team_id = ?", t.ID).Delete(&TeamMember{})
_, err = s.Where("team_id = ?", t.ID).Delete(&TeamMember{})
if err != nil {
return
}
// Delete team <-> namespace relations
_, err = x.Where("team_id = ?", t.ID).Delete(&TeamNamespace{})
_, err = s.Where("team_id = ?", t.ID).Delete(&TeamNamespace{})
if err != nil {
return
}
// Delete team <-> lists relations
_, err = x.Where("team_id = ?", t.ID).Delete(&TeamList{})
_, err = s.Where("team_id = ?", t.ID).Delete(&TeamList{})
if err != nil {
return
}
@ -329,25 +328,25 @@ func (t *Team) Delete() (err error) {
// @Failure 400 {object} web.HTTPError "Invalid team object provided."
// @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id} [post]
func (t *Team) Update() (err error) {
func (t *Team) Update(s *xorm.Session) (err error) {
// Check if we have a name
if t.Name == "" {
return ErrTeamNameCannotBeEmpty{}
}
// Check if the team exists
_, err = GetTeamByID(t.ID)
_, err = GetTeamByID(s, t.ID)
if err != nil {
return
}
_, err = x.ID(t.ID).Update(t)
_, err = s.ID(t.ID).Update(t)
if err != nil {
return
}
// Get the newly updated team
team, err := GetTeamByID(t.ID)
team, err := GetTeamByID(s, t.ID)
if team != nil {
*t = *team
}

View File

@ -18,10 +18,11 @@ package models
import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if the user can create a new team
func (t *Team) CanCreate(a web.Auth) (bool, error) {
func (t *Team) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is {
return false, nil
}
@ -31,39 +32,40 @@ func (t *Team) CanCreate(a web.Auth) (bool, error) {
}
// CanUpdate checks if the user can update a team
func (t *Team) CanUpdate(a web.Auth) (bool, error) {
return t.IsAdmin(a)
func (t *Team) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return t.IsAdmin(s, a)
}
// CanDelete checks if a user can delete a team
func (t *Team) CanDelete(a web.Auth) (bool, error) {
return t.IsAdmin(a)
func (t *Team) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return t.IsAdmin(s, a)
}
// IsAdmin returns true when the user is admin of a team
func (t *Team) IsAdmin(a web.Auth) (bool, error) {
func (t *Team) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
// Don't do anything if we're deadling with a link share auth here
if _, is := a.(*LinkSharing); is {
return false, nil
}
// Check if the team exists to be able to return a proper error message if not
_, err := GetTeamByID(t.ID)
_, err := GetTeamByID(s, t.ID)
if err != nil {
return false, err
}
return x.Where("team_id = ?", t.ID).
return s.Where("team_id = ?", t.ID).
And("user_id = ?", a.GetID()).
And("admin = ?", true).
Get(&TeamMember{})
}
// CanRead returns true if the user has read access to the team
func (t *Team) CanRead(a web.Auth) (bool, int, error) {
func (t *Team) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
// Check if the user is in the team
tm := &TeamMember{}
can, err := x.Where("team_id = ?", t.ID).
can, err := s.
Where("team_id = ?", t.ID).
And("user_id = ?", a.GetID()).
Get(tm)

View File

@ -82,6 +82,8 @@ func TestTeam_CanDoSomething(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &Team{
ID: tt.fields.ID,
@ -96,19 +98,19 @@ func TestTeam_CanDoSomething(t *testing.T) {
Rights: tt.fields.Rights,
}
if got, _ := tm.CanCreate(tt.args.a); got != tt.want["CanCreate"] { // CanCreate is currently always true
if got, _ := tm.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] { // CanCreate is currently always true
t.Errorf("Team.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
}
if got, _ := tm.CanDelete(tt.args.a); got != tt.want["CanDelete"] {
if got, _ := tm.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("Team.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
}
if got, _ := tm.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] {
if got, _ := tm.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("Team.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
}
if got, _, _ := tm.CanRead(tt.args.a); got != tt.want["CanRead"] {
if got, _, _ := tm.CanRead(s, tt.args.a); got != tt.want["CanRead"] {
t.Errorf("Team.CanRead() = %v, want %v", got, tt.want["CanRead"])
}
if got, _ := tm.IsAdmin(tt.args.a); got != tt.want["IsAdmin"] {
if got, _ := tm.IsAdmin(s, tt.args.a); got != tt.want["IsAdmin"] {
t.Errorf("Team.IsAdmin() = %v, want %v", got, tt.want["IsAdmin"])
}
})

View File

@ -32,11 +32,16 @@ func TestTeam_Create(t *testing.T) {
}
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{
Name: "Testteam293",
Description: "Lorem Ispum",
}
err := team.Create(doer)
err := team.Create(s, doer)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "teams", map[string]interface{}{
"id": team.ID,
@ -46,8 +51,11 @@ func TestTeam_Create(t *testing.T) {
})
t.Run("empty name", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{}
err := team.Create(doer)
err := team.Create(s, doer)
assert.Error(t, err)
assert.True(t, IsErrTeamNameCannotBeEmpty(err))
})
@ -56,8 +64,11 @@ func TestTeam_Create(t *testing.T) {
func TestTeam_ReadOne(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ID: 1}
err := team.ReadOne()
err := team.ReadOne(s)
assert.NoError(t, err)
assert.Equal(t, "testteam1", team.Name)
assert.Equal(t, "Lorem Ipsum", team.Description)
@ -66,15 +77,21 @@ func TestTeam_ReadOne(t *testing.T) {
})
t.Run("invalid id", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ID: -1}
err := team.ReadOne()
err := team.ReadOne(s)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ID: 99999}
err := team.ReadOne()
err := team.ReadOne(s)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
})
@ -83,23 +100,31 @@ func TestTeam_ReadOne(t *testing.T) {
func TestTeam_ReadAll(t *testing.T) {
doer := &user.User{ID: 1}
t.Run("normal", func(t *testing.T) {
s := db.NewSession()
defer s.Close()
team := &Team{}
ts, _, _, err := team.ReadAll(doer, "", 1, 50)
teams, _, _, err := team.ReadAll(s, doer, "", 1, 50)
assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(ts).Kind(), reflect.Slice)
s := reflect.ValueOf(ts)
assert.Equal(t, 8, s.Len())
assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice)
ts := reflect.ValueOf(teams)
assert.Equal(t, 8, ts.Len())
})
}
func TestTeam_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{
ID: 1,
Name: "SomethingNew",
}
err := team.Update()
err := team.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "teams", map[string]interface{}{
"id": team.ID,
@ -108,21 +133,27 @@ func TestTeam_Update(t *testing.T) {
})
t.Run("empty name", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{
ID: 1,
Name: "",
}
err := team.Update()
err := team.Update(s)
assert.Error(t, err)
assert.True(t, IsErrTeamNameCannotBeEmpty(err))
})
t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{
ID: 9999,
Name: "SomethingNew",
}
err := team.Update()
err := team.Update(s)
assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err))
})
@ -131,10 +162,15 @@ func TestTeam_Update(t *testing.T) {
func TestTeam_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{
ID: 1,
}
err := team.Delete()
err := team.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "teams", map[string]interface{}{
"id": 1,

View File

@ -16,7 +16,10 @@
package models
import "code.vikunja.io/api/pkg/files"
import (
"code.vikunja.io/api/pkg/files"
"xorm.io/xorm"
)
// Unsplash requires us to do pingbacks to their site and also name the image author.
// To do this properly, we need to save these information somewhere.
@ -36,15 +39,15 @@ func (u *UnsplashPhoto) TableName() string {
}
// Save persists an unsplash photo to the db
func (u *UnsplashPhoto) Save() error {
_, err := x.Insert(u)
func (u *UnsplashPhoto) Save(s *xorm.Session) error {
_, err := s.Insert(u)
return err
}
// GetUnsplashPhotoByFileID returns an unsplash photo by its saved file id
func GetUnsplashPhotoByFileID(fileID int64) (u *UnsplashPhoto, err error) {
func GetUnsplashPhotoByFileID(s *xorm.Session, fileID int64) (u *UnsplashPhoto, err error) {
u = &UnsplashPhoto{}
exists, err := x.Where("file_id = ?", fileID).Get(u)
exists, err := s.Where("file_id = ?", fileID).Get(u)
if err != nil {
return
}
@ -55,10 +58,10 @@ func GetUnsplashPhotoByFileID(fileID int64) (u *UnsplashPhoto, err error) {
}
// RemoveUnsplashPhoto removes an unsplash photo from the db
func RemoveUnsplashPhoto(fileID int64) (err error) {
func RemoveUnsplashPhoto(s *xorm.Session, fileID int64) (err error) {
// This is intentionally "fire and forget" which is why we don't check if we have an
// unsplash entry for that file at all. If there is one, it will be deleted.
// We do this to keep the function simple.
_, err = x.Where("file_id = ?", fileID).Delete(&UnsplashPhoto{})
_, err = s.Where("file_id = ?", fileID).Delete(&UnsplashPhoto{})
return
}

View File

@ -20,6 +20,7 @@ package models
import (
"code.vikunja.io/api/pkg/user"
"xorm.io/builder"
"xorm.io/xorm"
)
// ListUIDs hold all kinds of user IDs from accounts who have somehow access to a list
@ -33,11 +34,11 @@ type ListUIDs struct {
}
// ListUsersFromList returns a list with all users who have access to a list, regardless of the method which gave them access
func ListUsersFromList(l *List, search string) (users []*user.User, err error) {
func ListUsersFromList(s *xorm.Session, l *List, search string) (users []*user.User, err error) {
userids := []*ListUIDs{}
err = x.
err = s.
Select(`l.owner_id as listOwner,
un.user_id as unID,
ul.user_id as ulID,
@ -97,7 +98,7 @@ func ListUsersFromList(l *List, search string) (users []*user.User, err error) {
}
// Get all users
err = x.
err = s.
Table("users").
Select("*").
In("id", uids).

View File

@ -201,8 +201,10 @@ func TestListUsersFromList(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
gotUsers, err := ListUsersFromList(tt.args.l, tt.args.search)
gotUsers, err := ListUsersFromList(s, tt.args.l, tt.args.search)
if (err != nil) != tt.wantErr {
t.Errorf("ListUsersFromList() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -23,6 +23,9 @@ import (
"net/http"
"time"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth"
@ -130,8 +133,17 @@ func HandleCallback(c echo.Context) error {
return err
}
s := db.NewSession()
defer s.Close()
// Check if we have seen this user before
u, err := getOrCreateUser(cl, idToken.Issuer, idToken.Subject)
u, err := getOrCreateUser(s, cl, idToken.Issuer, idToken.Subject)
if err != nil {
_ = s.Rollback()
return err
}
err = s.Commit()
if err != nil {
return err
}
@ -140,9 +152,9 @@ func HandleCallback(c echo.Context) error {
return auth.NewUserAuthTokenResponse(u, c)
}
func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err error) {
func getOrCreateUser(s *xorm.Session, cl *claims, issuer, subject string) (u *user.User, err error) {
// Check if the user exists for that issuer and subject
u, err = user.GetUserWithEmail(&user.User{
u, err = user.GetUserWithEmail(s, &user.User{
Issuer: issuer,
Subject: subject,
})
@ -165,7 +177,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
uu.Username = petname.Generate(3, "-")
}
u, err = user.CreateUser(uu)
u, err = user.CreateUser(s, uu)
if err != nil && !user.IsErrUsernameExists(err) {
return nil, err
}
@ -173,14 +185,14 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
// If their preferred username is already taken, create some random one from the email and subject
if user.IsErrUsernameExists(err) {
uu.Username = petname.Generate(3, "-")
u, err = user.CreateUser(uu)
u, err = user.CreateUser(s, uu)
if err != nil {
return nil, err
}
}
// And create its namespace
err = models.CreateNewNamespaceForUser(u)
err = models.CreateNewNamespaceForUser(s, u)
if err != nil {
return nil, err
}
@ -196,7 +208,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
if cl.Name != u.Name {
u.Name = cl.Name
}
u, err = user.UpdateUser(&user.User{
u, err = user.UpdateUser(s, &user.User{
ID: u.ID,
Email: u.Email,
Name: u.Name,

View File

@ -26,12 +26,18 @@ import (
func TestGetOrCreateUser(t *testing.T) {
t.Run("new user", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "test@example.com",
PreferredUsername: "someUserWhoDoesNotExistYet",
}
u, err := getOrCreateUser(cl, "https://some.issuer", "12345")
u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID,
"email": cl.Email,
@ -40,13 +46,19 @@ func TestGetOrCreateUser(t *testing.T) {
})
t.Run("new user, no username provided", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "test@example.com",
PreferredUsername: "",
}
u, err := getOrCreateUser(cl, "https://some.issuer", "12345")
u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.NoError(t, err)
assert.NotEmpty(t, u.Username)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID,
"email": cl.Email,
@ -54,19 +66,28 @@ func TestGetOrCreateUser(t *testing.T) {
})
t.Run("new user, no email address", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "",
}
_, err := getOrCreateUser(cl, "https://some.issuer", "12345")
_, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.Error(t, err)
})
t.Run("existing user, different email address", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{
Email: "other-email-address@some.service.com",
}
u, err := getOrCreateUser(cl, "https://some.service.com", "12345")
u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345")
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID,
"email": cl.Email,

View File

@ -19,6 +19,7 @@ package background
import (
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/web"
"xorm.io/xorm"
)
// Image represents an image which can be used as a list background
@ -33,7 +34,7 @@ type Image struct {
// Provider represents something that is able to get a list of images and set one of them as background
type Provider interface {
// Search is used to either return a pre-defined list of Image or let the user search for an image
Search(search string, page int64) (result []*Image, err error)
Search(s *xorm.Session, search string, page int64) (result []*Image, err error)
// Set sets an image which was most likely previously obtained by Search as list background
Set(image *Image, list *models.List, auth web.Auth) (err error)
Set(s *xorm.Session, image *Image, list *models.List, auth web.Auth) (err error)
}

View File

@ -22,6 +22,9 @@ import (
"strconv"
"strings"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
@ -59,8 +62,17 @@ func (bp *BackgroundProvider) SearchBackgrounds(c echo.Context) error {
}
}
result, err := p.Search(search, page)
s := db.NewSession()
defer s.Close()
result, err := p.Search(s, search, page)
if err != nil {
_ = s.Rollback()
return echo.NewHTTPError(http.StatusBadRequest, "An error occurred: "+err.Error())
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return echo.NewHTTPError(http.StatusBadRequest, "An error occurred: "+err.Error())
}
@ -68,7 +80,7 @@ func (bp *BackgroundProvider) SearchBackgrounds(c echo.Context) error {
}
// This function does all kinds of preparations for setting and uploading a background
func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *models.List, auth web.Auth, err error) {
func (bp *BackgroundProvider) setBackgroundPreparations(s *xorm.Session, c echo.Context) (list *models.List, auth web.Auth, err error) {
auth, err = auth2.GetAuthFromClaims(c)
if err != nil {
return nil, nil, echo.NewHTTPError(http.StatusBadRequest, "Invalid auth token: "+err.Error())
@ -81,7 +93,7 @@ func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *m
// Check if the user has the right to change the list background
list = &models.List{ID: listID}
can, err := list.CanUpdate(auth)
can, err := list.CanUpdate(s, auth)
if err != nil {
return
}
@ -90,14 +102,18 @@ func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *m
return list, auth, models.ErrGenericForbidden{}
}
// Load the list
err = list.GetSimpleByID()
list, err = models.GetListSimpleByID(s, list.ID)
return
}
// SetBackground sets an Image as list background
func (bp *BackgroundProvider) SetBackground(c echo.Context) error {
list, auth, err := bp.setBackgroundPreparations(c)
s := db.NewSession()
defer s.Close()
list, auth, err := bp.setBackgroundPreparations(s, c)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -106,11 +122,13 @@ func (bp *BackgroundProvider) SetBackground(c echo.Context) error {
image := &background.Image{}
err = c.Bind(image)
if err != nil {
_ = s.Rollback()
return echo.NewHTTPError(http.StatusBadRequest, "No or invalid model provided: "+err.Error())
}
err = p.Set(image, list, auth)
err = p.Set(s, image, list, auth)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
return c.JSON(http.StatusOK, list)
@ -118,8 +136,12 @@ func (bp *BackgroundProvider) SetBackground(c echo.Context) error {
// UploadBackground uploads a background and passes the id of the uploaded file as an Image to the Set function of the BackgroundProvider.
func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
list, auth, err := bp.setBackgroundPreparations(c)
s := db.NewSession()
defer s.Close()
list, auth, err := bp.setBackgroundPreparations(s, c)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -128,10 +150,12 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
// Get + upload the image
file, err := c.FormFile("background")
if err != nil {
_ = s.Rollback()
return err
}
src, err := file.Open()
if err != nil {
_ = s.Rollback()
return err
}
defer src.Close()
@ -139,9 +163,11 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
// Validate we're dealing with an image
mime, err := mimetype.DetectReader(src)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if !strings.HasPrefix(mime.String(), "image") {
_ = s.Rollback()
return c.JSON(http.StatusBadRequest, models.Message{Message: "Uploaded file is no image."})
}
_, _ = src.Seek(0, io.SeekStart)
@ -149,6 +175,7 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
// Save the file
f, err := files.CreateWithMime(src, file.Filename, uint64(file.Size), auth, mime.String())
if err != nil {
_ = s.Rollback()
if files.IsErrFileIsTooLarge(err) {
return echo.ErrBadRequest
}
@ -158,10 +185,17 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
image := &background.Image{ID: strconv.FormatInt(f.ID, 10)}
err = p.Set(image, list, auth)
err = p.Set(s, image, list, auth)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
return c.JSON(http.StatusOK, list)
}
@ -190,17 +224,23 @@ func GetListBackground(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid list ID: "+err.Error())
}
s := db.NewSession()
defer s.Close()
// Check if a background for this list exists + Rights
list := &models.List{ID: listID}
can, _, err := list.CanRead(auth)
can, _, err := list.CanRead(s, auth)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if !can {
_ = s.Rollback()
log.Infof("Tried to get list background of list %d while not having the rights for it (User: %v)", listID, auth)
return echo.NewHTTPError(http.StatusForbidden)
}
if list.BackgroundFileID == 0 {
_ = s.Rollback()
return echo.NotFoundHandler(c)
}
@ -209,13 +249,19 @@ func GetListBackground(c echo.Context) error {
ID: list.BackgroundFileID,
}
if err := bgFile.LoadFileByID(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Unsplash requires pingbacks as per their api usage guidelines.
// To do this in a privacy-preserving manner, we do the ping from inside of Vikunja to not expose any user details.
// FIXME: This should use an event once we have events
unsplash.Pingback(bgFile)
unsplash.Pingback(s, bgFile)
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Serve the file
return c.Stream(http.StatusOK, "image/jpg", bgFile.File)

View File

@ -26,6 +26,8 @@ import (
"strings"
"time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log"
@ -150,7 +152,7 @@ func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) {
// @Success 200 {array} background.Image "An array with photos"
// @Failure 500 {object} models.Message "Internal error"
// @Router /backgrounds/unsplash/search [get]
func (p *Provider) Search(search string, page int64) (result []*background.Image, err error) {
func (p *Provider) Search(s *xorm.Session, search string, page int64) (result []*background.Image, err error) {
// If we don't have a search query, return results from the unsplash featured collection
if search == "" {
@ -243,7 +245,7 @@ func (p *Provider) Search(search string, page int64) (result []*background.Image
// @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/backgrounds/unsplash [post]
func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth) (err error) {
func (p *Provider) Set(s *xorm.Session, image *background.Image, list *models.List, auth web.Auth) (err error) {
// Find the photo
photo, err := getUnsplashPhotoInfoByID(image.ID)
@ -292,7 +294,7 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
return err
}
if err := models.RemoveUnsplashPhoto(list.BackgroundFileID); err != nil {
if err := models.RemoveUnsplashPhoto(s, list.BackgroundFileID); err != nil {
return err
}
}
@ -304,7 +306,7 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
Author: photo.User.Username,
AuthorName: photo.User.Name,
}
err = unsplashPhoto.Save()
err = unsplashPhoto.Save(s)
if err != nil {
return
}
@ -315,13 +317,13 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
list.BackgroundInformation = unsplashPhoto
// Set it as the list background
return models.SetListBackground(list.ID, file)
return models.SetListBackground(s, list.ID, file)
}
// Pingback pings the unsplash api if an unsplash photo has been accessed.
func Pingback(f *files.File) {
func Pingback(s *xorm.Session, f *files.File) {
// Check if the file is actually downloaded from unsplash
unsplashPhoto, err := models.GetUnsplashPhotoByFileID(f.ID)
unsplashPhoto, err := models.GetUnsplashPhotoByFileID(s, f.ID)
if err != nil {
if files.IsErrFileIsNotUnsplashFile(err) {
return

View File

@ -19,6 +19,8 @@ package upload
import (
"strconv"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/background"
@ -30,7 +32,7 @@ type Provider struct {
}
// Search is only used to implement the interface
func (p *Provider) Search(search string, page int64) (result []*background.Image, err error) {
func (p *Provider) Search(s *xorm.Session, search string, page int64) (result []*background.Image, err error) {
return
}
@ -50,7 +52,7 @@ func (p *Provider) Search(search string, page int64) (result []*background.Image
// @Failure 404 {object} models.Message "The list does not exist."
// @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/backgrounds/upload [put]
func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth) (err error) {
func (p *Provider) Set(s *xorm.Session, image *background.Image, list *models.List, auth web.Auth) (err error) {
// Remove the old background if one exists
if list.BackgroundFileID != 0 {
file := files.File{ID: list.BackgroundFileID}
@ -67,5 +69,5 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
list.BackgroundInformation = &models.ListBackgroundType{Type: models.ListBackgroundUpload}
return models.SetListBackground(list.ID, file)
return models.SetListBackground(s, list.ID, file)
}

View File

@ -20,6 +20,8 @@ import (
"bytes"
"io/ioutil"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
@ -34,10 +36,14 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
labels := make(map[string]*models.Label)
s := db.NewSession()
defer s.Close()
// Create all namespaces
for _, n := range str {
err = n.Create(user)
err = n.Create(s, user)
if err != nil {
_ = s.Rollback()
return
}
@ -54,8 +60,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
needsDefaultBucket := false
l.NamespaceID = n.ID
err = l.Create(user)
err = l.Create(s, user)
if err != nil {
_ = s.Rollback()
return
}
@ -67,11 +74,13 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
file, err := files.Create(backgroundFile, "", uint64(backgroundFile.Len()), user)
if err != nil {
_ = s.Rollback()
return err
}
err = models.SetListBackground(l.ID, file)
err = models.SetListBackground(s, l.ID, file)
if err != nil {
_ = s.Rollback()
return err
}
@ -87,8 +96,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
oldID := bucket.ID
bucket.ID = 0 // We want a new id
bucket.ListID = l.ID
err = bucket.Create(user)
err = bucket.Create(s, user)
if err != nil {
_ = s.Rollback()
return
}
buckets[oldID] = bucket
@ -111,8 +121,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
}
t.ListID = l.ID
err = t.Create(user)
err = t.Create(s, user)
if err != nil {
_ = s.Rollback()
return
}
@ -132,8 +143,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
// First create the related tasks if they do not exist
if rt.ID == 0 {
rt.ListID = t.ListID
err = rt.Create(user)
err = rt.Create(s, user)
if err != nil {
_ = s.Rollback()
return
}
log.Debugf("[creating structure] Created related task %d", rt.ID)
@ -145,8 +157,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
OtherTaskID: rt.ID,
RelationKind: kind,
}
err = taskRel.Create(user)
err = taskRel.Create(s, user)
if err != nil {
_ = s.Rollback()
return
}
@ -164,8 +177,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
if len(a.File.FileContent) > 0 {
a.TaskID = t.ID
fr := ioutil.NopCloser(bytes.NewReader(a.File.FileContent))
err = a.NewAttachment(fr, a.File.Name, a.File.Size, user)
err = a.NewAttachment(s, fr, a.File.Name, a.File.Size, user)
if err != nil {
_ = s.Rollback()
return
}
log.Debugf("[creating structure] Created new attachment %d", a.ID)
@ -180,8 +194,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
var exists bool
lb, exists = labels[label.Title+label.HexColor]
if !exists {
err = label.Create(user)
err = label.Create(s, user)
if err != nil {
_ = s.Rollback()
return err
}
log.Debugf("[creating structure] Created new label %d", label.ID)
@ -193,8 +208,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
LabelID: lb.ID,
TaskID: t.ID,
}
err = lt.Create(user)
err = lt.Create(s, user)
if err != nil {
_ = s.Rollback()
return err
}
log.Debugf("[creating structure] Associated task %d with label %d", t.ID, lb.ID)
@ -204,13 +220,15 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
// All tasks brought their own bucket with them, therefore the newly created default bucket is just extra space
if !needsDefaultBucket {
b := &models.Bucket{ListID: l.ID}
bucketsIn, _, _, err := b.ReadAll(user, "", 1, 1)
bucketsIn, _, _, err := b.ReadAll(s, user, "", 1, 1)
if err != nil {
_ = s.Rollback()
return err
}
buckets := bucketsIn.([]*models.Bucket)
err = buckets[0].Delete()
err = buckets[0].Delete(s)
if err != nil {
_ = s.Rollback()
return err
}
}
@ -222,5 +240,5 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
log.Debugf("[creating structure] Done inserting new task structure")
return nil
return s.Commit()
}

View File

@ -19,20 +19,10 @@ package migration
import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"xorm.io/xorm"
)
var x *xorm.Engine
// InitDB sets up the database connection to use in this module
func InitDB() (err error) {
x, err = db.CreateDBEngine()
if err != nil {
log.Criticalf("Could not connect to db: %v", err.Error())
return
}
// Cache
if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" {
db.RegisterTableStructsForCache(GetTables())

View File

@ -19,6 +19,7 @@ package migration
import (
"time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/user"
)
@ -37,17 +38,26 @@ func (s *Status) TableName() string {
// SetMigrationStatus sets the migration status for a user
func SetMigrationStatus(m Migrator, u *user.User) (err error) {
s := db.NewSession()
defer s.Close()
status := &Status{
UserID: u.ID,
MigratorName: m.Name(),
}
_, err = x.Insert(status)
_, err = s.Insert(status)
return
}
// GetMigrationStatus returns the migration status for a migration and a user
func GetMigrationStatus(m Migrator, u *user.User) (status *Status, err error) {
s := db.NewSession()
defer s.Close()
status = &Status{}
_, err = x.Where("user_id = ? and migrator_name = ?", u.ID, m.Name()).Desc("id").Get(status)
_, err = s.
Where("user_id = ? and migrator_name = ?", u.ID, m.Name()).
Desc("id").
Get(status)
return
}

View File

@ -17,6 +17,7 @@
package v1
import (
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
@ -56,8 +57,11 @@ func GetAvatar(c echo.Context) error {
// Get the username
username := c.Param("username")
s := db.NewSession()
defer s.Close()
// Get the user
u, err := user.GetUserWithEmail(&user.User{Username: username})
u, err := user.GetUserWithEmail(s, &user.User{Username: username})
if err != nil {
log.Errorf("Error getting user for avatar: %v", err)
return handler.HandleHTTPError(err, c)
@ -113,22 +117,28 @@ func GetAvatar(c echo.Context) error {
// @Router /user/settings/avatar/upload [put]
func UploadAvatar(c echo.Context) (err error) {
s := db.NewSession()
defer s.Close()
uc, err := user.GetCurrentUser(c)
if err != nil {
return handler.HandleHTTPError(err, c)
}
u, err := user.GetUserByID(uc.ID)
u, err := user.GetUserByID(s, uc.ID)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Get + upload the image
file, err := c.FormFile("avatar")
if err != nil {
_ = s.Rollback()
return err
}
src, err := file.Open()
if err != nil {
_ = s.Rollback()
return err
}
defer src.Close()
@ -136,6 +146,7 @@ func UploadAvatar(c echo.Context) (err error) {
// Validate we're dealing with an image
mime, err := mimetype.DetectReader(src)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if !strings.HasPrefix(mime.String(), "image") {
@ -148,6 +159,7 @@ func UploadAvatar(c echo.Context) (err error) {
f := &files.File{ID: u.AvatarFileID}
if err := f.Delete(); err != nil {
if !files.IsErrFileDoesNotExist(err) {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
}
@ -157,11 +169,13 @@ func UploadAvatar(c echo.Context) (err error) {
// Resize the new file to a max height of 1024
img, _, err := image.Decode(src)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
resizedImg := imaging.Resize(img, 0, 1024, imaging.Lanczos)
buf := &bytes.Buffer{}
if err := png.Encode(buf, resizedImg); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -170,6 +184,7 @@ func UploadAvatar(c echo.Context) (err error) {
// Save the file
f, err := files.CreateWithMime(buf, file.Filename, uint64(file.Size), u, "image/png")
if err != nil {
_ = s.Rollback()
if files.IsErrFileIsTooLarge(err) {
return echo.ErrBadRequest
}
@ -180,7 +195,13 @@ func UploadAvatar(c echo.Context) (err error) {
u.AvatarFileID = f.ID
u.AvatarProvider = "upload"
if _, err := user.UpdateUser(u); err != nil {
if _, err := user.UpdateUser(s, u); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth"
"code.vikunja.io/web/handler"
@ -45,8 +47,18 @@ type LinkShareToken struct {
// @Router /shares/{share}/auth [post]
func AuthenticateLinkShare(c echo.Context) error {
hash := c.Param("share")
share, err := models.GetLinkShareByHash(hash)
s := db.NewSession()
defer s.Close()
share, err := models.GetLinkShareByHash(s, hash)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -20,6 +20,9 @@ import (
"net/http"
"strconv"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler"
@ -41,8 +44,11 @@ import (
// @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/lists [get]
func GetListsByNamespaceID(c echo.Context) error {
s := db.NewSession()
defer s.Close()
// Get our namespace
namespace, err := getNamespace(c)
namespace, err := getNamespace(s, c)
if err != nil {
return handler.HandleHTTPError(err, c)
}
@ -53,14 +59,14 @@ func GetListsByNamespaceID(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
lists, err := models.GetListsByNamespaceID(namespace.ID, doer)
lists, err := models.GetListsByNamespaceID(s, namespace.ID, doer)
if err != nil {
return handler.HandleHTTPError(err, c)
}
return c.JSON(http.StatusOK, lists)
}
func getNamespace(c echo.Context) (namespace *models.Namespace, err error) {
func getNamespace(s *xorm.Session, c echo.Context) (namespace *models.Namespace, err error) {
// Check if we have our ID
id := c.Param("namespace")
// Make int
@ -75,12 +81,12 @@ func getNamespace(c echo.Context) (namespace *models.Namespace, err error) {
}
// Check if the user has acces to that namespace
user, err := user.GetCurrentUser(c)
u, err := user.GetCurrentUser(c)
if err != nil {
return
}
namespace = &models.Namespace{ID: namespaceID}
canRead, _, err := namespace.CanRead(user)
canRead, _, err := namespace.CanRead(s, u)
if err != nil {
return namespace, err
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth"
user2 "code.vikunja.io/api/pkg/user"
@ -45,27 +47,38 @@ func Login(c echo.Context) error {
return c.JSON(http.StatusBadRequest, models.Message{Message: "Please provide a username and password."})
}
s := db.NewSession()
defer s.Close()
// Check user
user, err := user2.CheckUserCredentials(&u)
user, err := user2.CheckUserCredentials(s, &u)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
totpEnabled, err := user2.TOTPEnabledForUser(user)
totpEnabled, err := user2.TOTPEnabledForUser(s, user)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if totpEnabled {
_, err = user2.ValidateTOTPPasscode(&user2.TOTPPasscode{
_, err = user2.ValidateTOTPPasscode(s, &user2.TOTPPasscode{
User: user,
Passcode: u.TOTPPasscode,
})
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Create token
return auth.NewUserAuthTokenResponse(user, c)
}
@ -82,18 +95,23 @@ func Login(c echo.Context) error {
// @Router /user/token [post]
func RenewToken(c echo.Context) (err error) {
s := db.NewSession()
defer s.Close()
jwtinf := c.Get("user").(*jwt.Token)
claims := jwtinf.Claims.(jwt.MapClaims)
typ := int(claims["type"].(float64))
if typ == auth.AuthTypeLinkShare {
share := &models.LinkSharing{}
share.ID = int64(claims["id"].(float64))
err := share.ReadOne()
err := share.ReadOne(s)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
t, err := auth.NewLinkShareJWTAuthtoken(share)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
return c.JSON(http.StatusOK, auth.Token{Token: t})
@ -101,11 +119,18 @@ func RenewToken(c echo.Context) (err error) {
u, err := user2.GetUserFromClaims(claims)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
auth2 "code.vikunja.io/api/pkg/modules/auth"
"code.vikunja.io/web/handler"
@ -52,8 +54,12 @@ func UploadTaskAttachment(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
can, err := taskAttachment.CanCreate(auth)
s := db.NewSession()
defer s.Close()
can, err := taskAttachment.CanCreate(s, auth)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if !can {
@ -63,6 +69,7 @@ func UploadTaskAttachment(c echo.Context) error {
// Multipart form
form, err := c.MultipartForm()
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -85,7 +92,7 @@ func UploadTaskAttachment(c echo.Context) error {
}
defer f.Close()
err = ta.NewAttachment(f, file.Filename, uint64(file.Size), auth)
err = ta.NewAttachment(s, f, file.Filename, uint64(file.Size), auth)
if err != nil {
r.Errors = append(r.Errors, handler.HandleHTTPError(err, c))
continue
@ -93,6 +100,11 @@ func UploadTaskAttachment(c echo.Context) error {
r.Success = append(r.Success, ta)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
return c.JSON(http.StatusOK, r)
}
@ -121,8 +133,13 @@ func GetTaskAttachment(c echo.Context) error {
if err != nil {
return handler.HandleHTTPError(err, c)
}
can, _, err := taskAttachment.CanRead(auth)
s := db.NewSession()
defer s.Close()
can, _, err := taskAttachment.CanRead(s, auth)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if !can {
@ -130,14 +147,21 @@ func GetTaskAttachment(c echo.Context) error {
}
// Get the attachment incl file
err = taskAttachment.ReadOne()
err = taskAttachment.ReadOne(s)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Open an send the file to the client
err = taskAttachment.File.LoadFileByID()
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler"
@ -43,8 +45,17 @@ func UserConfirmEmail(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "No token provided.")
}
err := user.ConfirmEmail(&emailConfirm)
s := db.NewSession()
defer s.Close()
err := user.ConfirmEmail(s, &emailConfirm)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -20,6 +20,8 @@ import (
"net/http"
"strconv"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
auth2 "code.vikunja.io/api/pkg/modules/auth"
"code.vikunja.io/api/pkg/user"
@ -40,9 +42,19 @@ import (
// @Failure 500 {object} models.Message "Internal server error."
// @Router /users [get]
func UserList(c echo.Context) error {
s := c.QueryParam("s")
users, err := user.ListUsers(s)
search := c.QueryParam("s")
s := db.NewSession()
defer s.Close()
users, err := user.ListUsers(s, search)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -80,17 +92,27 @@ func ListUsersForList(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
canRead, _, err := list.CanRead(auth)
s := db.NewSession()
defer s.Close()
canRead, _, err := list.CanRead(s, auth)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if !canRead {
return echo.ErrForbidden
}
s := c.QueryParam("s")
users, err := models.ListUsersFromList(&list, s)
search := c.QueryParam("s")
users, err := models.ListUsersFromList(s, &list, search)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler"
@ -43,8 +45,17 @@ func UserResetPassword(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "No password provided.")
}
err := user.ResetPassword(&pwReset)
s := db.NewSession()
defer s.Close()
err := user.ResetPassword(s, &pwReset)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -73,8 +84,17 @@ func UserRequestResetPasswordToken(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, err)
}
err := user.RequestUserPasswordResetTokenByEmail(&pwTokenReset)
s := db.NewSession()
defer s.Close()
err := user.RequestUserPasswordResetTokenByEmail(s, &pwTokenReset)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user"
@ -50,15 +52,25 @@ func RegisterUser(c echo.Context) error {
return c.JSON(http.StatusBadRequest, models.Message{Message: "No or invalid user model provided."})
}
s := db.NewSession()
defer s.Close()
// Insert the user
newUser, err := user.CreateUser(datUser.APIFormat())
newUser, err := user.CreateUser(s, datUser.APIFormat())
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Add its namespace
err = models.CreateNewNamespaceForUser(newUser)
err = models.CreateNewNamespaceForUser(s, newUser)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler"
@ -57,8 +59,17 @@ func GetUserAvatarProvider(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
s := db.NewSession()
defer s.Close()
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -91,15 +102,25 @@ func ChangeUserAvatarProvider(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
s := db.NewSession()
defer s.Close()
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
user.AvatarProvider = uap.AvatarProvider
_, err = user2.UpdateUser(user)
_, err = user2.UpdateUser(s, user)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -129,16 +150,26 @@ func UpdateGeneralUserSettings(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID})
s := db.NewSession()
defer s.Close()
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
user.Name = us.Name
user.EmailRemindersEnabled = us.EmailRemindersEnabled
_, err = user2.UpdateUser(user)
_, err = user2.UpdateUser(s, user)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler"
"github.com/labstack/echo/v4"
@ -41,8 +43,17 @@ func UserShow(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, "Error getting current user.")
}
user, err := user2.GetUserByID(userInfos.ID)
s := db.NewSession()
defer s.Close()
user, err := user2.GetUserByID(s, userInfos.ID)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -22,6 +22,8 @@ import (
"image/jpeg"
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user"
@ -47,8 +49,17 @@ func UserTOTPEnroll(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
t, err := user.EnrollTOTP(u)
s := db.NewSession()
defer s.Close()
t, err := user.EnrollTOTP(s, u)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -86,8 +97,17 @@ func UserTOTPEnable(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid model provided.")
}
err = user.EnableTOTP(passcode)
s := db.NewSession()
defer s.Close()
err = user.EnableTOTP(s, passcode)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -122,18 +142,29 @@ func UserTOTPDisable(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
u, err = user.GetUserByID(u.ID)
s := db.NewSession()
defer s.Close()
u, err = user.GetUserByID(s, u.ID)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
err = user.CheckUserPassword(u, login.Password)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
err = user.DisableTOTP(u)
err = user.DisableTOTP(s, u)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -156,14 +187,24 @@ func UserTOTPQrCode(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
qrcode, err := user.GetTOTPQrCodeForUser(u)
s := db.NewSession()
defer s.Close()
qrcode, err := user.GetTOTPQrCodeForUser(s, u)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
buff := &bytes.Buffer{}
err = jpeg.Encode(buff, qrcode, nil)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
@ -186,8 +227,17 @@ func UserTOTP(c echo.Context) error {
return handler.HandleHTTPError(err, c)
}
t, err := user.GetTOTPForUser(u)
s := db.NewSession()
defer s.Close()
t, err := user.GetTOTPForUser(s, u)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -20,6 +20,8 @@ import (
"fmt"
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user"
@ -56,16 +58,26 @@ func UpdateUserEmail(c echo.Context) (err error) {
return handler.HandleHTTPError(err, c)
}
emailUpdate.User, err = user.CheckUserCredentials(&user.Login{
s := db.NewSession()
defer s.Close()
emailUpdate.User, err = user.CheckUserCredentials(s, &user.Login{
Username: emailUpdate.User.Username,
Password: emailUpdate.Password,
})
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
err = user.UpdateEmail(emailUpdate)
err = user.UpdateEmail(s, emailUpdate)
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -19,6 +19,8 @@ package v1
import (
"net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler"
@ -61,13 +63,23 @@ func UserChangePassword(c echo.Context) error {
return handler.HandleHTTPError(user.ErrEmptyOldPassword{}, c)
}
s := db.NewSession()
defer s.Close()
// Check the current password
if _, err = user.CheckUserCredentials(&user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil {
if _, err = user.CheckUserCredentials(s, &user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Update the password
if err = user.UpdateUserPassword(doer, newPW.NewPassword); err != nil {
if err = user.UpdateUserPassword(s, doer, newPW.NewPassword); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}

View File

@ -21,6 +21,8 @@ import (
"strings"
"time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
user2 "code.vikunja.io/api/pkg/user"
@ -90,9 +92,16 @@ func (vcls *VikunjaCaldavListStorage) GetResources(rpath string, withChildren bo
return []data.Resource{r}, nil
}
s := db.NewSession()
defer s.Close()
// Otherwise get all lists
thelists, _, _, err := vcls.list.ReadAll(vcls.user, "", -1, 50)
thelists, _, _, err := vcls.list.ReadAll(s, vcls.user, "", -1, 50)
if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err
}
lists := thelists.([]*models.List)
@ -125,10 +134,17 @@ func (vcls *VikunjaCaldavListStorage) GetResourcesByList(rpaths []string) ([]dat
uids = append(uids, string(uid[:endlen]))
}
s := db.NewSession()
defer s.Close()
// GetTasksByUIDs...
// Parse these into ressources...
tasks, err := models.GetTasksByUIDs(uids)
tasks, err := models.GetTasksByUIDs(s, uids)
if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err
}
@ -187,15 +203,22 @@ func (vcls *VikunjaCaldavListStorage) GetResource(rpath string) (*data.Resource,
// If the task is not nil, we need to get the task and not the list
if vcls.task != nil {
s := db.NewSession()
defer s.Close()
// save and override the updated unix date to not break any later etag checks
updated := vcls.task.Updated
task, err := models.GetTaskSimple(&models.Task{ID: vcls.task.ID, UID: vcls.task.UID})
task, err := models.GetTaskSimple(s, &models.Task{ID: vcls.task.ID, UID: vcls.task.UID})
if err != nil {
_ = s.Rollback()
if models.IsErrTaskDoesNotExist(err) {
return nil, false, errs.ResourceNotFoundError
}
return nil, false, err
}
if err := s.Commit(); err != nil {
return nil, false, err
}
vcls.task = &task
if updated.Unix() > 0 {
@ -230,6 +253,9 @@ func (vcls *VikunjaCaldavListStorage) GetShallowResource(rpath string) (*data.Re
// CreateResource creates a new resource
func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*data.Resource, error) {
s := db.NewSession()
defer s.Close()
vTask, err := parseTaskFromVTODO(content)
if err != nil {
return nil, err
@ -238,7 +264,7 @@ func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*da
vTask.ListID = vcls.list.ID
// Check the rights
canCreate, err := vTask.CanCreate(vcls.user)
canCreate, err := vTask.CanCreate(s, vcls.user)
if err != nil {
return nil, err
}
@ -247,8 +273,13 @@ func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*da
}
// Create the task
err = vTask.Create(vcls.user)
err = vTask.Create(s, vcls.user)
if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err
}
@ -272,18 +303,28 @@ func (vcls *VikunjaCaldavListStorage) UpdateResource(rpath, content string) (*da
// At this point, we already have the right task in vcls.task, so we can use that ID directly
vTask.ID = vcls.task.ID
s := db.NewSession()
defer s.Close()
// Check the rights
canUpdate, err := vTask.CanUpdate(vcls.user)
canUpdate, err := vTask.CanUpdate(s, vcls.user)
if err != nil {
_ = s.Rollback()
return nil, err
}
if !canUpdate {
_ = s.Rollback()
return nil, errs.ForbiddenError
}
// Update the task
err = vTask.Update()
err = vTask.Update(s)
if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err
}
@ -299,9 +340,13 @@ func (vcls *VikunjaCaldavListStorage) UpdateResource(rpath, content string) (*da
// DeleteResource deletes a resource
func (vcls *VikunjaCaldavListStorage) DeleteResource(rpath string) error {
if vcls.task != nil {
s := db.NewSession()
defer s.Close()
// Check the rights
canDelete, err := vcls.task.CanDelete(vcls.user)
canDelete, err := vcls.task.CanDelete(s, vcls.user)
if err != nil {
_ = s.Rollback()
return err
}
if !canDelete {
@ -309,7 +354,13 @@ func (vcls *VikunjaCaldavListStorage) DeleteResource(rpath string) error {
}
// Delete it
return vcls.task.Delete()
err = vcls.task.Delete(s)
if err != nil {
_ = s.Rollback()
return err
}
return s.Commit()
}
return nil
@ -385,16 +436,22 @@ func (vlra *VikunjaListResourceAdapter) GetModTime() time.Time {
}
func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr VikunjaListResourceAdapter, err error) {
can, _, err := vcls.list.CanRead(vcls.user)
s := db.NewSession()
defer s.Close()
can, _, err := vcls.list.CanRead(s, vcls.user)
if err != nil {
_ = s.Rollback()
return
}
if !can {
_ = s.Rollback()
log.Errorf("User %v tried to access a caldav resource (List %v) which they are not allowed to access", vcls.user.Username, vcls.list.ID)
return rr, models.ErrUserDoesNotHaveAccessToList{ListID: vcls.list.ID}
}
err = vcls.list.ReadOne()
err = vcls.list.ReadOne(s)
if err != nil {
_ = s.Rollback()
return
}
@ -403,8 +460,9 @@ func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr Vi
tk := models.TaskCollection{
ListID: vcls.list.ID,
}
iface, _, _, err := tk.ReadAll(vcls.user, "", 1, 1000)
iface, _, _, err := tk.ReadAll(s, vcls.user, "", 1, 1000)
if err != nil {
_ = s.Rollback()
return rr, err
}
tasks, ok := iface.([]*models.Task)
@ -416,6 +474,10 @@ func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr Vi
vcls.list.Tasks = tasks
}
if err := s.Commit(); err != nil {
return rr, err
}
rr = VikunjaListResourceAdapter{
list: vcls.list,
listTasks: listTasks,

View File

@ -50,11 +50,8 @@ import (
"strings"
"time"
microsofttodo "code.vikunja.io/api/pkg/modules/migration/microsoft-todo"
"code.vikunja.io/api/pkg/modules/migration/trello"
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth"
@ -65,7 +62,9 @@ import (
"code.vikunja.io/api/pkg/modules/background/upload"
"code.vikunja.io/api/pkg/modules/migration"
migrationHandler "code.vikunja.io/api/pkg/modules/migration/handler"
microsofttodo "code.vikunja.io/api/pkg/modules/migration/microsoft-todo"
"code.vikunja.io/api/pkg/modules/migration/todoist"
"code.vikunja.io/api/pkg/modules/migration/trello"
"code.vikunja.io/api/pkg/modules/migration/wunderlist"
apiv1 "code.vikunja.io/api/pkg/routes/api/v1"
"code.vikunja.io/api/pkg/routes/caldav"
@ -175,6 +174,7 @@ func NewEcho() *echo.Echo {
})
handler.SetLoggingProvider(log.GetLogger())
handler.SetMaxItemsPerPage(config.ServiceMaxItemsPerPage.GetInt())
handler.SetSessionFactory(db.NewSession)
return e
}
@ -601,11 +601,19 @@ func caldavBasicAuth(username, password string, c echo.Context) (bool, error) {
Username: username,
Password: password,
}
u, err := user.CheckUserCredentials(creds)
s := db.NewSession()
defer s.Close()
u, err := user.CheckUserCredentials(s, creds)
if err != nil {
_ = s.Rollback()
log.Errorf("Error during basic auth for caldav: %v", err)
return false, nil
}
if err := s.Commit(); err != nil {
return false, err
}
// Save the user in echo context for later use
c.Set("userBasicAuth", u)
return true, nil

View File

@ -20,20 +20,10 @@ package user
import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"xorm.io/xorm"
)
var x *xorm.Engine
// InitDB sets up the database connection to use in this module
func InitDB() (err error) {
x, err = db.CreateDBEngine()
if err != nil {
log.Criticalf("Could not connect to db: %v", err.Error())
return
}
// Cache
if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" {
db.RegisterTableStructsForCache(GetTables())

View File

@ -24,8 +24,7 @@ import (
// InitTests handles the actual bootstrapping of the test env
func InitTests() {
var err error
x, err = db.CreateTestEngine()
x, err := db.CreateTestEngine()
if err != nil {
log.Fatal(err)
}

View File

@ -19,6 +19,8 @@ package user
import (
"image"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/config"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
@ -47,19 +49,19 @@ type TOTPPasscode struct {
}
// TOTPEnabledForUser checks if totp is enabled for a user - not if it is activated, use GetTOTPForUser to check that.
func TOTPEnabledForUser(user *User) (bool, error) {
func TOTPEnabledForUser(s *xorm.Session, user *User) (bool, error) {
if !config.ServiceEnableTotp.GetBool() {
return false, nil
}
t := &TOTP{}
_, err := x.Where("user_id = ?", user.ID).Get(t)
_, err := s.Where("user_id = ?", user.ID).Get(t)
return t.Enabled, err
}
// GetTOTPForUser returns the current state of totp settings for the user.
func GetTOTPForUser(user *User) (t *TOTP, err error) {
func GetTOTPForUser(s *xorm.Session, user *User) (t *TOTP, err error) {
t = &TOTP{}
exists, err := x.Where("user_id = ?", user.ID).Get(t)
exists, err := s.Where("user_id = ?", user.ID).Get(t)
if err != nil {
return
}
@ -71,8 +73,8 @@ func GetTOTPForUser(user *User) (t *TOTP, err error) {
}
// EnrollTOTP creates a new TOTP entry for the user - it does not enable it yet.
func EnrollTOTP(user *User) (t *TOTP, err error) {
isEnrolled, err := x.Where("user_id = ?", user.ID).Exist(&TOTP{})
func EnrollTOTP(s *xorm.Session, user *User) (t *TOTP, err error) {
isEnrolled, err := s.Where("user_id = ?", user.ID).Exist(&TOTP{})
if err != nil {
return
}
@ -94,18 +96,18 @@ func EnrollTOTP(user *User) (t *TOTP, err error) {
Enabled: false,
URL: key.URL(),
}
_, err = x.Insert(t)
_, err = s.Insert(t)
return
}
// EnableTOTP enables totp for a user. The provided passcode is used to verify the user has a working totp setup.
func EnableTOTP(passcode *TOTPPasscode) (err error) {
t, err := ValidateTOTPPasscode(passcode)
func EnableTOTP(s *xorm.Session, passcode *TOTPPasscode) (err error) {
t, err := ValidateTOTPPasscode(s, passcode)
if err != nil {
return
}
_, err = x.
_, err = s.
Where("id = ?", t.ID).
Cols("enabled").
Update(&TOTP{Enabled: true})
@ -113,14 +115,16 @@ func EnableTOTP(passcode *TOTPPasscode) (err error) {
}
// DisableTOTP removes all totp settings for a user.
func DisableTOTP(user *User) (err error) {
_, err = x.Where("user_id = ?", user.ID).Delete(&TOTP{})
func DisableTOTP(s *xorm.Session, user *User) (err error) {
_, err = s.
Where("user_id = ?", user.ID).
Delete(&TOTP{})
return
}
// ValidateTOTPPasscode validated totp codes of users.
func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) {
t, err = GetTOTPForUser(passcode.User)
func ValidateTOTPPasscode(s *xorm.Session, passcode *TOTPPasscode) (t *TOTP, err error) {
t, err = GetTOTPForUser(s, passcode.User)
if err != nil {
return
}
@ -133,8 +137,8 @@ func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) {
}
// GetTOTPQrCodeForUser returns a qrcode for a user's totp setting
func GetTOTPQrCodeForUser(user *User) (qrcode image.Image, err error) {
t, err := GetTOTPForUser(user)
func GetTOTPQrCodeForUser(s *xorm.Session, user *User) (qrcode image.Image, err error) {
t, err := GetTOTPForUser(s, user)
if err != nil {
return
}

View File

@ -20,6 +20,7 @@ import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/mail"
"code.vikunja.io/api/pkg/utils"
"xorm.io/xorm"
)
// EmailUpdate is the data structure to update a user's email address
@ -32,11 +33,11 @@ type EmailUpdate struct {
}
// UpdateEmail lets a user update their email address
func UpdateEmail(update *EmailUpdate) (err error) {
func UpdateEmail(s *xorm.Session, update *EmailUpdate) (err error) {
// Check the email is not already used
user := &User{}
has, err := x.Where("email = ?", update.NewEmail).Get(user)
has, err := s.Where("email = ?", update.NewEmail).Get(user)
if err != nil {
return
}
@ -46,7 +47,7 @@ func UpdateEmail(update *EmailUpdate) (err error) {
}
// Set the user as unconfirmed and the new email address
update.User, err = GetUserWithEmail(&User{ID: update.User.ID})
update.User, err = GetUserWithEmail(s, &User{ID: update.User.ID})
if err != nil {
return
}
@ -54,7 +55,7 @@ func UpdateEmail(update *EmailUpdate) (err error) {
update.User.IsActive = false
update.User.Email = update.NewEmail
update.User.EmailConfirmToken = utils.MakeRandomString(64)
_, err = x.
_, err = s.
Where("id = ?", update.User.ID).
Cols("email", "is_active", "email_confirm_token").
Update(update.User)

Some files were not shown because too many files have changed in this diff Show More