From 6b40df50d372fb403756756b2d0fc50b2bb7a3ef Mon Sep 17 00:00:00 2001 From: konrad Date: Mon, 31 Dec 2018 01:18:41 +0000 Subject: [PATCH] Add labels to tasks (#45) --- Featurecreep.md | 4 +- REST-Tests/labels.http | 56 + docs/docs.go | 535 +++- docs/errors.md | 3 + docs/swagger/swagger.json | 533 ++++ docs/swagger/swagger.yaml | 360 +++ go.mod | 2 +- pkg/models/error.go | 87 + pkg/models/fixtures/label_task.yml | 3 + pkg/models/fixtures/labels.yml | 12 + pkg/models/fixtures/list.yml | 8 +- pkg/models/fixtures/tasks.yml | 6 + pkg/models/label.go | 59 + pkg/models/label_create_update.go | 87 + pkg/models/label_read.go | 106 + pkg/models/label_rights.go | 83 + pkg/models/label_task.go | 153 + pkg/models/label_task_rights.go | 62 + pkg/models/label_task_test.go | 279 ++ pkg/models/label_test.go | 452 +++ pkg/models/list_tasks.go | 102 +- pkg/models/list_tasks_rights.go | 17 +- pkg/models/models.go | 4 + pkg/routes/routes.go | 20 + .../honnef.co/go/tools/callgraph/callgraph.go | 129 + .../go/tools/callgraph/static/static.go | 35 + vendor/honnef.co/go/tools/callgraph/util.go | 181 ++ .../go/tools/cmd/staticcheck/README.md | 16 + .../go/tools/cmd/staticcheck/staticcheck.go | 23 + .../honnef.co/go/tools/deprecated/stdlib.go | 54 + .../honnef.co/go/tools/functions/concrete.go | 56 + .../honnef.co/go/tools/functions/functions.go | 150 + vendor/honnef.co/go/tools/functions/loops.go | 50 + vendor/honnef.co/go/tools/functions/pure.go | 123 + .../go/tools/functions/terminates.go | 24 + .../go/tools/staticcheck/CONTRIBUTING.md | 15 + .../go/tools/staticcheck/buildtag.go | 21 + vendor/honnef.co/go/tools/staticcheck/lint.go | 2790 +++++++++++++++++ .../honnef.co/go/tools/staticcheck/rules.go | 322 ++ .../go/tools/staticcheck/vrp/channel.go | 73 + .../honnef.co/go/tools/staticcheck/vrp/int.go | 476 +++ .../go/tools/staticcheck/vrp/slice.go | 273 ++ .../go/tools/staticcheck/vrp/string.go | 258 ++ .../honnef.co/go/tools/staticcheck/vrp/vrp.go | 1049 +++++++ vendor/modules.txt | 7 + 45 files changed, 9101 insertions(+), 57 deletions(-) create mode 100644 REST-Tests/labels.http create mode 100644 pkg/models/fixtures/label_task.yml create mode 100644 pkg/models/fixtures/labels.yml create mode 100644 pkg/models/label.go create mode 100644 pkg/models/label_create_update.go create mode 100644 pkg/models/label_read.go create mode 100644 pkg/models/label_rights.go create mode 100644 pkg/models/label_task.go create mode 100644 pkg/models/label_task_rights.go create mode 100644 pkg/models/label_task_test.go create mode 100644 pkg/models/label_test.go create mode 100644 vendor/honnef.co/go/tools/callgraph/callgraph.go create mode 100644 vendor/honnef.co/go/tools/callgraph/static/static.go create mode 100644 vendor/honnef.co/go/tools/callgraph/util.go create mode 100644 vendor/honnef.co/go/tools/cmd/staticcheck/README.md create mode 100644 vendor/honnef.co/go/tools/cmd/staticcheck/staticcheck.go create mode 100644 vendor/honnef.co/go/tools/deprecated/stdlib.go create mode 100644 vendor/honnef.co/go/tools/functions/concrete.go create mode 100644 vendor/honnef.co/go/tools/functions/functions.go create mode 100644 vendor/honnef.co/go/tools/functions/loops.go create mode 100644 vendor/honnef.co/go/tools/functions/pure.go create mode 100644 vendor/honnef.co/go/tools/functions/terminates.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/CONTRIBUTING.md create mode 100644 vendor/honnef.co/go/tools/staticcheck/buildtag.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/lint.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/rules.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/vrp/channel.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/vrp/int.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/vrp/slice.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/vrp/string.go create mode 100644 vendor/honnef.co/go/tools/staticcheck/vrp/vrp.go diff --git a/Featurecreep.md b/Featurecreep.md index 32c3425706f..0ef49d8b7c0 100644 --- a/Featurecreep.md +++ b/Featurecreep.md @@ -100,7 +100,7 @@ Sorry for some of them being in German, I'll tranlate them at some point. * [x] Tasks innerhalb eines definierbarem Bereich, sollte aber trotzdem der server machen, so à la "Gib mir alles für diesen Monat" * [x] Bulk-edit -> Transactions * [x] Assignees -* [ ] Labels +* [x] Labels * [ ] Attachments * [ ] Task-Templates innerhalb namespaces und Listen (-> Mehrere, die auswählbar sind) * [ ] Ein Task muss von mehreren Assignees abgehakt werden bis er als done markiert wird @@ -109,6 +109,8 @@ Sorry for some of them being in German, I'll tranlate them at some point. ### General features * [x] Deps nach mod umziehen +* [ ] Performance bei rechtchecks verbessern + * User & Teamright sollte sich für n rechte in einer Funktion testen lassen * [ ] Globale Limits für anlegbare Listen + Namespaces * [ ] "Smart Lists", Listen nach bestimmten Kriterien gefiltert -> nur UI? * [ ] "Performance-Statistik" -> Wie viele Tasks man in bestimmten Zeiträumen so geschafft hat etc diff --git a/REST-Tests/labels.http b/REST-Tests/labels.http new file mode 100644 index 00000000000..0689769df71 --- /dev/null +++ b/REST-Tests/labels.http @@ -0,0 +1,56 @@ +# Get all labels +GET http://localhost:8080/api/v1/labels +Authorization: Bearer {{auth_token}} + +### +# Add a new label +PUT http://localhost:8080/api/v1/labels +Authorization: Bearer {{auth_token}} +Content-Type: application/json + +{ + "title": "test5" +} + +### +# Delete a label +DELETE http://localhost:8080/api/v1/labels/6 +Authorization: Bearer {{auth_token}} + +### +# Update a label +POST http://localhost:8080/api/v1/labels/1 +Authorization: Bearer {{auth_token}} +Content-Type: application/json + +{ + "title": "testschinkenbrot", + "description": "käsebrot" +} + +### +# Get one label +GET http://localhost:8080/api/v1/labels/1 +Authorization: Bearer {{auth_token}} + +### +# Get all labels on a task +GET http://localhost:8080/api/v1/tasks/3565/labels +Authorization: Bearer {{auth_token}} + +### +# Add a new label to a task +PUT http://localhost:8080/api/v1/tasks/3565/labels +Authorization: Bearer {{auth_token}} +Content-Type: application/json + +{ + "label_id": 1 +} + +### +# Delete a label from a task +DELETE http://localhost:8080/api/v1/tasks/3565/labels/1 +Authorization: Bearer {{auth_token}} + +### \ No newline at end of file diff --git a/docs/docs.go b/docs/docs.go index 855150b4aff..38eca0e4769 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -1,6 +1,6 @@ // GENERATED BY THE COMMAND ABOVE; DO NOT EDIT // This file was generated by swaggo/swag at -// 2018-12-29 15:14:06.225275112 +0100 CET m=+0.295589005 +// 2018-12-30 21:42:08.56057367 +0100 CET m=+0.082542821 package docs @@ -25,6 +25,301 @@ var doc = `{ "host": "{{.Host}}", "basePath": "/api/v1", "paths": { + "/labels": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Returns all labels which are either created by the user or associated with a task the user has at least read-access to.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Get all labels a user has access to", + "parameters": [ + { + "type": "integer", + "description": "The page number. Used for pagination. If not provided, the first page of results is returned.", + "name": "p", + "in": "query" + }, + { + "type": "string", + "description": "Search labels by label text.", + "name": "s", + "in": "query" + } + ], + "responses": { + "200": { + "description": "The labels", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Creates a new label.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Create a label", + "parameters": [ + { + "description": "The label object", + "name": "label", + "in": "body", + "required": true, + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + } + ], + "responses": { + "200": { + "description": "The created label object.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "400": { + "description": "Invalid label object provided.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, + "/labels/{id}": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Returns one label by its ID.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Gets one label", + "parameters": [ + { + "type": "integer", + "description": "Label ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "The label", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "403": { + "description": "The user does not have access to the label", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "404": { + "description": "Label not found", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Update an existing label. The user needs to be the creator of the label to be able to do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Update a label", + "parameters": [ + { + "type": "integer", + "description": "Label ID", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "The label object", + "name": "label", + "in": "body", + "required": true, + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + } + ], + "responses": { + "200": { + "description": "The created label object.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "400": { + "description": "Invalid label object provided.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "403": { + "description": "Not allowed to update the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "404": { + "description": "Label not found.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Delete an existing label. The user needs to be the creator of the label to be able to do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Delete a label", + "parameters": [ + { + "type": "integer", + "description": "Label ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "The label was successfully deleted.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "403": { + "description": "Not allowed to delete the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "404": { + "description": "Label not found.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, "/lists": { "get": { "security": [ @@ -2349,6 +2644,205 @@ var doc = `{ } } }, + "/tasks/{task}/labels": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Returns all labels which are assicociated with a given task.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Get all labels on a task", + "parameters": [ + { + "type": "integer", + "description": "Task ID", + "name": "task", + "in": "path", + "required": true + }, + { + "type": "integer", + "description": "The page number. Used for pagination. If not provided, the first page of results is returned.", + "name": "p", + "in": "query" + }, + { + "type": "string", + "description": "Search labels by label text.", + "name": "s", + "in": "query" + } + ], + "responses": { + "200": { + "description": "The labels", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Add a label to a task. The user needs to have write-access to the list to be able do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Add a label to a task", + "parameters": [ + { + "type": "integer", + "description": "Task ID", + "name": "task", + "in": "path", + "required": true + }, + { + "description": "The label object", + "name": "label", + "in": "body", + "required": true, + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + } + ], + "responses": { + "200": { + "description": "The created label relation object.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "400": { + "description": "Invalid label object provided.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "403": { + "description": "Not allowed to add the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "404": { + "description": "The label does not exist.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, + "/tasks/{task}/labels/{label}": { + "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Remove a label from a task. The user needs to have write-access to the list to be able do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Remove a label from a task", + "parameters": [ + { + "type": "integer", + "description": "Task ID", + "name": "task", + "in": "path", + "required": true + }, + { + "type": "integer", + "description": "Label ID", + "name": "label", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "The label was successfully removed.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "403": { + "description": "Not allowed to remove the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "404": { + "description": "Label not found.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io.web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, "/teams": { "get": { "security": [ @@ -3040,6 +3534,12 @@ var doc = `{ "id": { "type": "integer" }, + "labels": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + }, "listID": { "type": "integer" }, @@ -3089,6 +3589,33 @@ var doc = `{ } } }, + "models.Label": { + "type": "object", + "properties": { + "created": { + "type": "integer" + }, + "created_by": { + "type": "object", + "$ref": "#/definitions/models.User" + }, + "description": { + "type": "string" + }, + "hex_color": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "updated": { + "type": "integer" + } + } + }, "models.List": { "type": "object", "properties": { @@ -3150,6 +3677,12 @@ var doc = `{ "id": { "type": "integer" }, + "labels": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + }, "listID": { "type": "integer" }, diff --git a/docs/errors.md b/docs/errors.md index ce9ed0a39c7..fbcecab7dea 100644 --- a/docs/errors.md +++ b/docs/errors.md @@ -23,6 +23,7 @@ This document describes the different errors Vikunja can return. | 4002 | 404 | The list task does not exist. | | 4003 | 403 | All bulk editing tasks must belong to the same list. | | 4004 | 403 | Need at least one task when bulk editing tasks. | +| 4005 | 403 | The user does not have the right to see the task. | | 5001 | 404 | The namspace does not exist. | | 5003 | 403 | The user does not have access to the specified namespace. | | 5006 | 400 | The namespace name cannot be empty. | @@ -39,3 +40,5 @@ This document describes the different errors Vikunja can return. | 7001 | 400 | The user right is invalid. | | 7002 | 409 | The user already has access to that list. | | 7003 | 403 | The user does not have access to that list. | +| 8001 | 403 | This label already exists on that task. | +| 8002 | 404 | The label does not exist. | \ No newline at end of file diff --git a/docs/swagger/swagger.json b/docs/swagger/swagger.json index bcf7bb02ec5..1a9439dd1ac 100644 --- a/docs/swagger/swagger.json +++ b/docs/swagger/swagger.json @@ -12,6 +12,301 @@ "host": "{{.Host}}", "basePath": "/api/v1", "paths": { + "/labels": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Returns all labels which are either created by the user or associated with a task the user has at least read-access to.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Get all labels a user has access to", + "parameters": [ + { + "type": "integer", + "description": "The page number. Used for pagination. If not provided, the first page of results is returned.", + "name": "p", + "in": "query" + }, + { + "type": "string", + "description": "Search labels by label text.", + "name": "s", + "in": "query" + } + ], + "responses": { + "200": { + "description": "The labels", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Creates a new label.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Create a label", + "parameters": [ + { + "description": "The label object", + "name": "label", + "in": "body", + "required": true, + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + } + ], + "responses": { + "200": { + "description": "The created label object.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "400": { + "description": "Invalid label object provided.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, + "/labels/{id}": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Returns one label by its ID.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Gets one label", + "parameters": [ + { + "type": "integer", + "description": "Label ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "The label", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "403": { + "description": "The user does not have access to the label", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "404": { + "description": "Label not found", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Update an existing label. The user needs to be the creator of the label to be able to do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Update a label", + "parameters": [ + { + "type": "integer", + "description": "Label ID", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "The label object", + "name": "label", + "in": "body", + "required": true, + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + } + ], + "responses": { + "200": { + "description": "The created label object.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "400": { + "description": "Invalid label object provided.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "403": { + "description": "Not allowed to update the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "404": { + "description": "Label not found.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Delete an existing label. The user needs to be the creator of the label to be able to do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Delete a label", + "parameters": [ + { + "type": "integer", + "description": "Label ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "The label was successfully deleted.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "403": { + "description": "Not allowed to delete the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "404": { + "description": "Label not found.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, "/lists": { "get": { "security": [ @@ -2336,6 +2631,205 @@ } } }, + "/tasks/{task}/labels": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Returns all labels which are assicociated with a given task.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Get all labels on a task", + "parameters": [ + { + "type": "integer", + "description": "Task ID", + "name": "task", + "in": "path", + "required": true + }, + { + "type": "integer", + "description": "The page number. Used for pagination. If not provided, the first page of results is returned.", + "name": "p", + "in": "query" + }, + { + "type": "string", + "description": "Search labels by label text.", + "name": "s", + "in": "query" + } + ], + "responses": { + "200": { + "description": "The labels", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Add a label to a task. The user needs to have write-access to the list to be able do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Add a label to a task", + "parameters": [ + { + "type": "integer", + "description": "Task ID", + "name": "task", + "in": "path", + "required": true + }, + { + "description": "The label object", + "name": "label", + "in": "body", + "required": true, + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + } + ], + "responses": { + "200": { + "description": "The created label relation object.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "400": { + "description": "Invalid label object provided.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "403": { + "description": "Not allowed to add the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "404": { + "description": "The label does not exist.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, + "/tasks/{task}/labels/{label}": { + "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Remove a label from a task. The user needs to have write-access to the list to be able do this.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "labels" + ], + "summary": "Remove a label from a task", + "parameters": [ + { + "type": "integer", + "description": "Task ID", + "name": "task", + "in": "path", + "required": true + }, + { + "type": "integer", + "description": "Label ID", + "name": "label", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "The label was successfully removed.", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Label" + } + }, + "403": { + "description": "Not allowed to remove the label.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "404": { + "description": "Label not found.", + "schema": { + "type": "object", + "$ref": "#/definitions/code.vikunja.io/web.HTTPError" + } + }, + "500": { + "description": "Internal error", + "schema": { + "type": "object", + "$ref": "#/definitions/models.Message" + } + } + } + } + }, "/teams": { "get": { "security": [ @@ -3026,6 +3520,12 @@ "id": { "type": "integer" }, + "labels": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + }, "listID": { "type": "integer" }, @@ -3075,6 +3575,33 @@ } } }, + "models.Label": { + "type": "object", + "properties": { + "created": { + "type": "integer" + }, + "created_by": { + "type": "object", + "$ref": "#/definitions/models.User" + }, + "description": { + "type": "string" + }, + "hex_color": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "updated": { + "type": "integer" + } + } + }, "models.List": { "type": "object", "properties": { @@ -3136,6 +3663,12 @@ "id": { "type": "integer" }, + "labels": { + "type": "array", + "items": { + "$ref": "#/definitions/models.Label" + } + }, "listID": { "type": "integer" }, diff --git a/docs/swagger/swagger.yaml b/docs/swagger/swagger.yaml index 354064767ca..fbed2ef956c 100644 --- a/docs/swagger/swagger.yaml +++ b/docs/swagger/swagger.yaml @@ -32,6 +32,10 @@ definitions: type: integer id: type: integer + labels: + items: + $ref: '#/definitions/models.Label' + type: array listID: type: integer parentTaskID: @@ -64,6 +68,24 @@ definitions: token: type: string type: object + models.Label: + properties: + created: + type: integer + created_by: + $ref: '#/definitions/models.User' + type: object + description: + type: string + hex_color: + type: string + id: + type: integer + title: + type: string + updated: + type: integer + type: object models.List: properties: created: @@ -105,6 +127,10 @@ definitions: type: integer id: type: integer + labels: + items: + $ref: '#/definitions/models.Label' + type: array listID: type: integer parentTaskID: @@ -369,6 +395,205 @@ info: title: Vikunja API version: '{{.Version}}' paths: + /labels: + get: + consumes: + - application/json + description: Returns all labels which are either created by the user or associated + with a task the user has at least read-access to. + parameters: + - description: The page number. Used for pagination. If not provided, the first + page of results is returned. + in: query + name: p + type: integer + - description: Search labels by label text. + in: query + name: s + type: string + produces: + - application/json + responses: + "200": + description: The labels + schema: + items: + $ref: '#/definitions/models.Label' + type: array + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Get all labels a user has access to + tags: + - labels + put: + consumes: + - application/json + description: Creates a new label. + parameters: + - description: The label object + in: body + name: label + required: true + schema: + $ref: '#/definitions/models.Label' + type: object + produces: + - application/json + responses: + "200": + description: The created label object. + schema: + $ref: '#/definitions/models.Label' + type: object + "400": + description: Invalid label object provided. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Create a label + tags: + - labels + /labels/{id}: + delete: + consumes: + - application/json + description: Delete an existing label. The user needs to be the creator of the + label to be able to do this. + parameters: + - description: Label ID + in: path + name: id + required: true + type: integer + produces: + - application/json + responses: + "200": + description: The label was successfully deleted. + schema: + $ref: '#/definitions/models.Label' + type: object + "403": + description: Not allowed to delete the label. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "404": + description: Label not found. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Delete a label + tags: + - labels + get: + consumes: + - application/json + description: Returns one label by its ID. + parameters: + - description: Label ID + in: path + name: id + required: true + type: integer + produces: + - application/json + responses: + "200": + description: The label + schema: + $ref: '#/definitions/models.Label' + type: object + "403": + description: The user does not have access to the label + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "404": + description: Label not found + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Gets one label + tags: + - labels + put: + consumes: + - application/json + description: Update an existing label. The user needs to be the creator of the + label to be able to do this. + parameters: + - description: Label ID + in: path + name: id + required: true + type: integer + - description: The label object + in: body + name: label + required: true + schema: + $ref: '#/definitions/models.Label' + type: object + produces: + - application/json + responses: + "200": + description: The created label object. + schema: + $ref: '#/definitions/models.Label' + type: object + "400": + description: Invalid label object provided. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "403": + description: Not allowed to update the label. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "404": + description: Label not found. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Update a label + tags: + - labels /lists: get: consumes: @@ -1750,6 +1975,141 @@ paths: summary: Update a task tags: - task + /tasks/{task}/labels: + get: + consumes: + - application/json + description: Returns all labels which are assicociated with a given task. + parameters: + - description: Task ID + in: path + name: task + required: true + type: integer + - description: The page number. Used for pagination. If not provided, the first + page of results is returned. + in: query + name: p + type: integer + - description: Search labels by label text. + in: query + name: s + type: string + produces: + - application/json + responses: + "200": + description: The labels + schema: + items: + $ref: '#/definitions/models.Label' + type: array + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Get all labels on a task + tags: + - labels + put: + consumes: + - application/json + description: Add a label to a task. The user needs to have write-access to the + list to be able do this. + parameters: + - description: Task ID + in: path + name: task + required: true + type: integer + - description: The label object + in: body + name: label + required: true + schema: + $ref: '#/definitions/models.Label' + type: object + produces: + - application/json + responses: + "200": + description: The created label relation object. + schema: + $ref: '#/definitions/models.Label' + type: object + "400": + description: Invalid label object provided. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "403": + description: Not allowed to add the label. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "404": + description: The label does not exist. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Add a label to a task + tags: + - labels + /tasks/{task}/labels/{label}: + delete: + consumes: + - application/json + description: Remove a label from a task. The user needs to have write-access + to the list to be able do this. + parameters: + - description: Task ID + in: path + name: task + required: true + type: integer + - description: Label ID + in: path + name: label + required: true + type: integer + produces: + - application/json + responses: + "200": + description: The label was successfully removed. + schema: + $ref: '#/definitions/models.Label' + type: object + "403": + description: Not allowed to remove the label. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "404": + description: Label not found. + schema: + $ref: '#/definitions/code.vikunja.io/web.HTTPError' + type: object + "500": + description: Internal error + schema: + $ref: '#/definitions/models.Message' + type: object + security: + - ApiKeyAuth: [] + summary: Remove a label from a task + tags: + - labels /tasks/all: get: consumes: diff --git a/go.mod b/go.mod index c77e821437c..ef0317f0372 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/go-openapi/swag v0.17.2 // indirect github.com/go-redis/redis v6.14.2+incompatible github.com/go-sql-driver/mysql v1.4.1 - github.com/go-xorm/builder v0.0.0-20170519032130-c8871c857d25 // indirect + github.com/go-xorm/builder v0.0.0-20170519032130-c8871c857d25 github.com/go-xorm/core v0.5.8 github.com/go-xorm/tests v0.5.6 // indirect github.com/go-xorm/xorm v0.0.0-20170930012613-29d4a0330a00 diff --git a/pkg/models/error.go b/pkg/models/error.go index a0a2040b359..4f0b59bc475 100644 --- a/pkg/models/error.go +++ b/pkg/models/error.go @@ -474,6 +474,34 @@ func (err ErrBulkTasksNeedAtLeastOne) HTTPError() web.HTTPError { return web.HTTPError{HTTPCode: http.StatusBadRequest, Code: ErrCodeBulkTasksNeedAtLeastOne, Message: "Need at least one tasks to do bulk editing."} } +// ErrNoRightToSeeTask represents an error where a user does not have the right to see a task +type ErrNoRightToSeeTask struct { + TaskID int64 + UserID int64 +} + +// IsErrNoRightToSeeTask checks if an error is ErrNoRightToSeeTask. +func IsErrNoRightToSeeTask(err error) bool { + _, ok := err.(ErrNoRightToSeeTask) + return ok +} + +func (err ErrNoRightToSeeTask) Error() string { + return fmt.Sprintf("User does not have the right to see the task [TaskID: %v, UserID: %v]", err.TaskID, err.UserID) +} + +// ErrCodeNoRightToSeeTask holds the unique world-error code of this error +const ErrCodeNoRightToSeeTask = 4005 + +// HTTPError holds the http error description +func (err ErrNoRightToSeeTask) HTTPError() web.HTTPError { + return web.HTTPError{ + HTTPCode: http.StatusForbidden, + Code: ErrCodeNoRightToSeeTask, + Message: "You don't have the right to see this task.", + } +} + // ================= // Namespace errors // ================= @@ -864,3 +892,62 @@ const ErrCodeUserDoesNotHaveAccessToList = 7003 func (err ErrUserDoesNotHaveAccessToList) HTTPError() web.HTTPError { return web.HTTPError{HTTPCode: http.StatusForbidden, Code: ErrCodeUserDoesNotHaveAccessToList, Message: "This user does not have access to the list."} } + +// ============= +// Label errors +// ============= + +// ErrLabelIsAlreadyOnTask represents an error where a label is already bound to a task +type ErrLabelIsAlreadyOnTask struct { + LabelID int64 + TaskID int64 +} + +// IsErrLabelIsAlreadyOnTask checks if an error is ErrLabelIsAlreadyOnTask. +func IsErrLabelIsAlreadyOnTask(err error) bool { + _, ok := err.(ErrLabelIsAlreadyOnTask) + return ok +} + +func (err ErrLabelIsAlreadyOnTask) Error() string { + return fmt.Sprintf("Label already exists on task [TaskID: %v, LabelID: %v]", err.TaskID, err.LabelID) +} + +// ErrCodeLabelIsAlreadyOnTask holds the unique world-error code of this error +const ErrCodeLabelIsAlreadyOnTask = 8001 + +// HTTPError holds the http error description +func (err ErrLabelIsAlreadyOnTask) HTTPError() web.HTTPError { + return web.HTTPError{ + HTTPCode: http.StatusBadRequest, + Code: ErrCodeLabelIsAlreadyOnTask, + Message: "This label already exists on the task.", + } +} + +// ErrLabelDoesNotExist represents an error where a label does not exist +type ErrLabelDoesNotExist struct { + LabelID int64 +} + +// IsErrLabelDoesNotExist checks if an error is ErrLabelDoesNotExist. +func IsErrLabelDoesNotExist(err error) bool { + _, ok := err.(ErrLabelDoesNotExist) + return ok +} + +func (err ErrLabelDoesNotExist) Error() string { + return fmt.Sprintf("Label does not exist [LabelID: %v]", err.LabelID) +} + +// ErrCodeLabelDoesNotExist holds the unique world-error code of this error +const ErrCodeLabelDoesNotExist = 8002 + +// HTTPError holds the http error description +func (err ErrLabelDoesNotExist) HTTPError() web.HTTPError { + return web.HTTPError{ + HTTPCode: http.StatusNotFound, + Code: ErrCodeLabelDoesNotExist, + Message: "This label does not exist.", + } +} diff --git a/pkg/models/fixtures/label_task.yml b/pkg/models/fixtures/label_task.yml new file mode 100644 index 00000000000..6c644e06a6b --- /dev/null +++ b/pkg/models/fixtures/label_task.yml @@ -0,0 +1,3 @@ +- id: 1 + task_id: 1 + label_id: 4 \ No newline at end of file diff --git a/pkg/models/fixtures/labels.yml b/pkg/models/fixtures/labels.yml new file mode 100644 index 00000000000..4332e9098da --- /dev/null +++ b/pkg/models/fixtures/labels.yml @@ -0,0 +1,12 @@ +- id: 1 + title: 'Label #1' + created_by_id: 1 +- id: 2 + title: 'Label #2' + created_by_id: 1 +- id: 3 + title: 'Label #3 - other user' + created_by_id: 2 +- id: 4 + title: 'Label #4 - visible via other task' + created_by_id: 2 \ No newline at end of file diff --git a/pkg/models/fixtures/list.yml b/pkg/models/fixtures/list.yml index 27fbd05bdcd..b45ee1a8706 100644 --- a/pkg/models/fixtures/list.yml +++ b/pkg/models/fixtures/list.yml @@ -21,4 +21,10 @@ title: Test4 description: Lorem Ipsum owner_id: 3 - namespace_id: 3 \ No newline at end of file + namespace_id: 3 +- + id: 5 + title: Test5 + description: Lorem Ipsum + owner_id: 5 + namespace_id: 5 \ No newline at end of file diff --git a/pkg/models/fixtures/tasks.yml b/pkg/models/fixtures/tasks.yml index 1f5ce52ec73..7ab818c5673 100644 --- a/pkg/models/fixtures/tasks.yml +++ b/pkg/models/fixtures/tasks.yml @@ -84,4 +84,10 @@ created_by_id: 1 list_id: 2 created: 1543626724 + updated: 1543626724 +- id: 14 + text: 'task #14 basic other list' + created_by_id: 5 + list_id: 5 + created: 1543626724 updated: 1543626724 \ No newline at end of file diff --git a/pkg/models/label.go b/pkg/models/label.go new file mode 100644 index 00000000000..164d55ccbd0 --- /dev/null +++ b/pkg/models/label.go @@ -0,0 +1,59 @@ +// Vikunja is a todo-list application to facilitate your life. +// Copyright 2018 Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package models + +import ( + "code.vikunja.io/web" +) + +// Label represents a label +type Label struct { + ID int64 `xorm:"int(11) autoincr not null unique pk" json:"id" param:"label"` + Title string `xorm:"varchar(250) not null" json:"title" valid:"runelength(3|250)"` + Description string `xorm:"varchar(250)" json:"description" valid:"runelength(0|250)"` + HexColor string `xorm:"varchar(6)" json:"hex_color" valid:"runelength(0|6)"` + + CreatedByID int64 `xorm:"int(11) not null" json:"-"` + CreatedBy *User `xorm:"-" json:"created_by"` + + Created int64 `xorm:"created" json:"created"` + Updated int64 `xorm:"updated" json:"updated"` + + web.CRUDable `xorm:"-" json:"-"` + web.Rights `xorm:"-" json:"-"` +} + +// TableName makes a pretty table name +func (Label) TableName() string { + return "labels" +} + +// LabelTask represents a relation between a label and a task +type LabelTask struct { + ID int64 `xorm:"int(11) autoincr not null unique pk" json:"id"` + TaskID int64 `xorm:"int(11) INDEX not null" json:"-" param:"listtask"` + LabelID int64 `xorm:"int(11) INDEX not null" json:"label_id" param:"label"` + Created int64 `xorm:"created" json:"created"` + + web.CRUDable `xorm:"-" json:"-"` + web.Rights `xorm:"-" json:"-"` +} + +// TableName makes a pretty table name +func (LabelTask) TableName() string { + return "label_task" +} diff --git a/pkg/models/label_create_update.go b/pkg/models/label_create_update.go new file mode 100644 index 00000000000..3ba1ea347a8 --- /dev/null +++ b/pkg/models/label_create_update.go @@ -0,0 +1,87 @@ +// Vikunja is a todo-list application to facilitate your life. +// Copyright 2018 Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package models + +import "code.vikunja.io/web" + +// Create creates a new label +// @Summary Create a label +// @Description Creates a new label. +// @tags labels +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param label body models.Label true "The label object" +// @Success 200 {object} models.Label "The created label object." +// @Failure 400 {object} code.vikunja.io/web.HTTPError "Invalid label object provided." +// @Failure 500 {object} models.Message "Internal error" +// @Router /labels [put] +func (l *Label) Create(a web.Auth) (err error) { + u, err := getUserWithError(a) + if err != nil { + return + } + + l.CreatedBy = u + l.CreatedByID = u.ID + + _, err = x.Insert(l) + return +} + +// Update updates a label +// @Summary Update a label +// @Description Update an existing label. The user needs to be the creator of the label to be able to do this. +// @tags labels +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param id path int true "Label ID" +// @Param label body models.Label true "The label object" +// @Success 200 {object} models.Label "The created label object." +// @Failure 400 {object} code.vikunja.io/web.HTTPError "Invalid label object provided." +// @Failure 403 {object} code.vikunja.io/web.HTTPError "Not allowed to update the label." +// @Failure 404 {object} code.vikunja.io/web.HTTPError "Label not found." +// @Failure 500 {object} models.Message "Internal error" +// @Router /labels/{id} [put] +func (l *Label) Update() (err error) { + _, err = x.ID(l.ID).Update(l) + if err != nil { + return + } + + err = l.ReadOne() + return +} + +// Delete deletes a label +// @Summary Delete a label +// @Description Delete an existing label. The user needs to be the creator of the label to be able to do this. +// @tags labels +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param id path int true "Label ID" +// @Success 200 {object} models.Label "The label was successfully deleted." +// @Failure 403 {object} code.vikunja.io/web.HTTPError "Not allowed to delete the label." +// @Failure 404 {object} code.vikunja.io/web.HTTPError "Label not found." +// @Failure 500 {object} models.Message "Internal error" +// @Router /labels/{id} [delete] +func (l *Label) Delete() (err error) { + _, err = x.ID(l.ID).Delete(&Label{}) + return err +} diff --git a/pkg/models/label_read.go b/pkg/models/label_read.go new file mode 100644 index 00000000000..c5c94efd062 --- /dev/null +++ b/pkg/models/label_read.go @@ -0,0 +1,106 @@ +// Vikunja is a todo-list application to facilitate your life. +// Copyright 2018 Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package models + +import ( + "code.vikunja.io/web" + "time" +) + +// ReadAll gets all labels a user can use +// @Summary Get all labels a user has access to +// @Description Returns all labels which are either created by the user or associated with a task the user has at least read-access to. +// @tags labels +// @Accept json +// @Produce json +// @Param p query int false "The page number. Used for pagination. If not provided, the first page of results is returned." +// @Param s query string false "Search labels by label text." +// @Security ApiKeyAuth +// @Success 200 {array} models.Label "The labels" +// @Failure 500 {object} models.Message "Internal error" +// @Router /labels [get] +func (l *Label) ReadAll(search string, a web.Auth, page int) (ls interface{}, err error) { + u, err := getUserWithError(a) + if err != nil { + return nil, err + } + + // Get all tasks + taskIDs, err := getUserTaskIDs(u) + if err != nil { + return nil, err + } + + return getLabelsByTaskIDs(search, u, page, taskIDs, true) +} + +// ReadOne gets one label +// @Summary Gets one label +// @Description Returns one label by its ID. +// @tags labels +// @Accept json +// @Produce json +// @Param id path int true "Label ID" +// @Security ApiKeyAuth +// @Success 200 {object} models.Label "The label" +// @Failure 403 {object} code.vikunja.io/web.HTTPError "The user does not have access to the label" +// @Failure 404 {object} code.vikunja.io/web.HTTPError "Label not found" +// @Failure 500 {object} models.Message "Internal error" +// @Router /labels/{id} [get] +func (l *Label) ReadOne() (err error) { + label, err := getLabelByIDSimple(l.ID) + if err != nil { + return err + } + *l = *label + + user, err := GetUserByID(l.CreatedByID) + if err != nil { + return err + } + + l.CreatedBy = &user + return +} + +func getLabelByIDSimple(labelID int64) (*Label, error) { + label := Label{} + exists, err := x.ID(labelID).Get(&label) + if err != nil { + return &label, err + } + + if !exists { + return &Label{}, ErrLabelDoesNotExist{labelID} + } + return &label, err +} + +// Helper method to get all task ids a user has +func getUserTaskIDs(u *User) (taskIDs []int64, err error) { + tasks, err := GetTasksByUser("", u, -1, SortTasksByUnsorted, time.Unix(0, 0), time.Unix(0, 0)) + if err != nil { + return nil, err + } + + // make a slice of task ids + for _, t := range tasks { + taskIDs = append(taskIDs, t.ID) + } + + return +} diff --git a/pkg/models/label_rights.go b/pkg/models/label_rights.go new file mode 100644 index 00000000000..9df2877f9c4 --- /dev/null +++ b/pkg/models/label_rights.go @@ -0,0 +1,83 @@ +// Vikunja is a todo-list application to facilitate your life. +// Copyright 2018 Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package models + +import ( + "code.vikunja.io/api/pkg/log" + "code.vikunja.io/web" + "github.com/go-xorm/builder" +) + +// CanUpdate checks if a user can update a label +func (l *Label) CanUpdate(a web.Auth) bool { + return l.isLabelOwner(a) // Only owners should be allowed to update a label +} + +// CanDelete checks if a user can delete a label +func (l *Label) CanDelete(a web.Auth) bool { + return l.isLabelOwner(a) // Only owners should be allowed to delete a label +} + +// CanRead checks if a user can read a label +func (l *Label) CanRead(a web.Auth) bool { + return l.hasAccessToLabel(a) +} + +// CanCreate checks if the user can create a label +// Currently a dummy. +func (l *Label) CanCreate(a web.Auth) bool { + return true +} + +func (l *Label) isLabelOwner(a web.Auth) bool { + u := getUserForRights(a) + lorig, err := getLabelByIDSimple(l.ID) + if err != nil { + log.Log.Errorf("Error occurred during isLabelOwner for Label: %v", err) + return false + } + return lorig.CreatedByID == u.ID +} + +// Helper method to check if a user can see a specific label +func (l *Label) hasAccessToLabel(a web.Auth) bool { + u := getUserForRights(a) + + // Get all tasks + taskIDs, err := getUserTaskIDs(u) + if err != nil { + log.Log.Errorf("Error occurred during hasAccessToLabel for Label: %v", err) + return false + } + + // Get all labels associated with these tasks + var labels []*Label + has, err := x.Table("labels"). + Select("labels.*"). + Join("LEFT", "label_task", "label_task.label_id = labels.id"). + Where("label_task.label_id != null OR labels.created_by_id = ?", u.ID). + Or(builder.In("label_task.task_id", taskIDs)). + And("labels.id = ?", l.ID). + GroupBy("labels.id"). + Exist(&labels) + if err != nil { + log.Log.Errorf("Error occurred during hasAccessToLabel for Label: %v", err) + return false + } + + return has +} diff --git a/pkg/models/label_task.go b/pkg/models/label_task.go new file mode 100644 index 00000000000..0ee3b9265f3 --- /dev/null +++ b/pkg/models/label_task.go @@ -0,0 +1,153 @@ +// Vikunja is a todo-list application to facilitate your life. +// Copyright 2018 Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package models + +import ( + "code.vikunja.io/web" + "github.com/go-xorm/builder" +) + +// Delete deletes a label on a task +// @Summary Remove a label from a task +// @Description Remove a label from a task. The user needs to have write-access to the list to be able do this. +// @tags labels +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param task path int true "Task ID" +// @Param label path int true "Label ID" +// @Success 200 {object} models.Label "The label was successfully removed." +// @Failure 403 {object} code.vikunja.io/web.HTTPError "Not allowed to remove the label." +// @Failure 404 {object} code.vikunja.io/web.HTTPError "Label not found." +// @Failure 500 {object} models.Message "Internal error" +// @Router /tasks/{task}/labels/{label} [delete] +func (l *LabelTask) Delete() (err error) { + _, err = x.Delete(&LabelTask{LabelID: l.LabelID, TaskID: l.TaskID}) + return err +} + +// Create adds a label to a task +// @Summary Add a label to a task +// @Description Add a label to a task. The user needs to have write-access to the list to be able do this. +// @tags labels +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param task path int true "Task ID" +// @Param label body models.Label true "The label object" +// @Success 200 {object} models.Label "The created label relation object." +// @Failure 400 {object} code.vikunja.io/web.HTTPError "Invalid label object provided." +// @Failure 403 {object} code.vikunja.io/web.HTTPError "Not allowed to add the label." +// @Failure 404 {object} code.vikunja.io/web.HTTPError "The label does not exist." +// @Failure 500 {object} models.Message "Internal error" +// @Router /tasks/{task}/labels [put] +func (l *LabelTask) Create(a web.Auth) (err error) { + // Check if the label is already added + exists, err := x.Exist(&LabelTask{LabelID: l.LabelID, TaskID: l.TaskID}) + if err != nil { + return err + } + if exists { + return ErrLabelIsAlreadyOnTask{l.LabelID, l.TaskID} + } + + // Insert it + _, err = x.Insert(l) + return err +} + +// ReadAll gets all labels on a task +// @Summary Get all labels on a task +// @Description Returns all labels which are assicociated with a given task. +// @tags labels +// @Accept json +// @Produce json +// @Param task path int true "Task ID" +// @Param p query int false "The page number. Used for pagination. If not provided, the first page of results is returned." +// @Param s query string false "Search labels by label text." +// @Security ApiKeyAuth +// @Success 200 {array} models.Label "The labels" +// @Failure 500 {object} models.Message "Internal error" +// @Router /tasks/{task}/labels [get] +func (l *LabelTask) ReadAll(search string, a web.Auth, page int) (labels interface{}, err error) { + u, err := getUserWithError(a) + if err != nil { + return nil, err + } + + // Check if the user has the right to see the task + task, err := GetListTaskByID(l.TaskID) + if err != nil { + return nil, err + } + + if !task.CanRead(a) { + return nil, ErrNoRightToSeeTask{l.TaskID, u.ID} + } + + return getLabelsByTaskIDs(search, u, page, []int64{l.TaskID}, false) +} + +type labelWithTaskID struct { + TaskID int64 + Label `xorm:"extends"` +} + +// Helper function to get all labels for a set of tasks +// Used when getting all labels for one task as well when getting all lables +func getLabelsByTaskIDs(search string, u *User, page int, taskIDs []int64, getUnusedLabels bool) (ls []*labelWithTaskID, err error) { + // Incl unused labels + var uidOrNil interface{} + var requestOrNil interface{} + if getUnusedLabels { + uidOrNil = u.ID + requestOrNil = "label_task.label_id != null OR labels.created_by_id = ?" + } + + // Get all labels associated with these labels + var labels []*labelWithTaskID + err = x.Table("labels"). + Select("labels.*, label_task.task_id"). + Join("LEFT", "label_task", "label_task.label_id = labels.id"). + Where(requestOrNil, uidOrNil). + Or(builder.In("label_task.task_id", taskIDs)). + And("labels.title LIKE ?", "%"+search+"%"). + GroupBy("labels.id"). + Limit(getLimitFromPageIndex(page)). + Find(&labels) + if err != nil { + return nil, err + } + + // Get all created by users + var userids []int64 + for _, l := range labels { + userids = append(userids, l.CreatedByID) + } + users := make(map[int64]*User) + err = x.In("id", userids).Find(&users) + if err != nil { + return nil, err + } + + // Put it all together + for in, l := range labels { + labels[in].CreatedBy = users[l.CreatedByID] + } + + return labels, err +} diff --git a/pkg/models/label_task_rights.go b/pkg/models/label_task_rights.go new file mode 100644 index 00000000000..2d0e4eb15f7 --- /dev/null +++ b/pkg/models/label_task_rights.go @@ -0,0 +1,62 @@ +// Vikunja is a todo-list application to facilitate your life. +// Copyright 2018 Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package models + +import ( + "code.vikunja.io/api/pkg/log" + "code.vikunja.io/web" +) + +// CanCreate checks if a user can add a label to a task +func (lt *LabelTask) CanCreate(a web.Auth) bool { + label, err := getLabelByIDSimple(lt.LabelID) + if err != nil { + log.Log.Errorf("Error during CanCreate for LabelTask: %v", err) + return false + } + + return label.hasAccessToLabel(a) && lt.canDoLabelTask(a) +} + +// CanDelete checks if a user can delete a label from a task +func (lt *LabelTask) CanDelete(a web.Auth) bool { + if !lt.canDoLabelTask(a) { + return false + } + + // We don't care here if the label exists or not. The only relevant thing here is if the relation already exists, + // throw an error. + exists, err := x.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) + if err != nil { + log.Log.Errorf("Error during CanDelete for LabelTask: %v", err) + return false + } + return exists +} + +// Helper function to check if a user can write to a task +// + is able to see the label +// always the same check for either deleting or adding a label to a task +func (lt *LabelTask) canDoLabelTask(a web.Auth) bool { + // A user can add a label to a task if he can write to the task + task, err := getTaskByIDSimple(lt.TaskID) + if err != nil { + log.Log.Error("Error occurred during canDoLabelTask for LabelTask: %v", err) + return false + } + return task.CanUpdate(a) +} diff --git a/pkg/models/label_task_test.go b/pkg/models/label_task_test.go new file mode 100644 index 00000000000..4956bbf7b23 --- /dev/null +++ b/pkg/models/label_task_test.go @@ -0,0 +1,279 @@ +package models + +import ( + "reflect" + "runtime" + "testing" + + "code.vikunja.io/web" +) + +func TestLabelTask_ReadAll(t *testing.T) { + type fields struct { + ID int64 + TaskID int64 + LabelID int64 + Created int64 + CRUDable web.CRUDable + Rights web.Rights + } + type args struct { + search string + a web.Auth + page int + } + tests := []struct { + name string + fields fields + args args + wantLabels interface{} + wantErr bool + errType func(error) bool + }{ + { + name: "normal", + fields: fields{ + TaskID: 1, + }, + args: args{ + a: &User{ID: 1}, + }, + wantLabels: []*labelWithTaskID{ + { + TaskID: 1, + Label: Label{ + ID: 4, + Title: "Label #4 - visible via other task", + CreatedByID: 2, + CreatedBy: &User{ + ID: 2, + Username: "user2", + Password: "1234", + Email: "user2@example.com", + }, + }, + }, + }, + }, + { + name: "no right to see the task", + fields: fields{ + TaskID: 14, + }, + args: args{ + a: &User{ID: 1}, + }, + wantErr: true, + errType: IsErrNoRightToSeeTask, + }, + { + name: "nonexistant task", + fields: fields{ + TaskID: 9999, + }, + args: args{ + a: &User{ID: 1}, + }, + wantErr: true, + errType: IsErrListTaskDoesNotExist, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &LabelTask{ + ID: tt.fields.ID, + TaskID: tt.fields.TaskID, + LabelID: tt.fields.LabelID, + Created: tt.fields.Created, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + gotLabels, err := l.ReadAll(tt.args.search, tt.args.a, tt.args.page) + if (err != nil) != tt.wantErr { + t.Errorf("LabelTask.ReadAll() error = %v, wantErr %v", err, tt.wantErr) + return + } + if (err != nil) && tt.wantErr && !tt.errType(err) { + t.Errorf("LabelTask.ReadAll() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) + } + if !reflect.DeepEqual(gotLabels, tt.wantLabels) { + t.Errorf("LabelTask.ReadAll() = %v, want %v", gotLabels, tt.wantLabels) + } + }) + } +} + +func TestLabelTask_Create(t *testing.T) { + type fields struct { + ID int64 + TaskID int64 + LabelID int64 + Created int64 + CRUDable web.CRUDable + Rights web.Rights + } + type args struct { + a web.Auth + } + tests := []struct { + name string + fields fields + args args + wantErr bool + errType func(error) bool + wantForbidden bool + }{ + { + name: "normal", + fields: fields{ + TaskID: 1, + LabelID: 1, + }, + args: args{ + a: &User{ID: 1}, + }, + }, + { + name: "already existing", + fields: fields{ + TaskID: 1, + LabelID: 1, + }, + args: args{ + a: &User{ID: 1}, + }, + wantErr: true, + errType: IsErrLabelIsAlreadyOnTask, + }, + { + name: "nonexisting label", + fields: fields{ + TaskID: 1, + LabelID: 9999, + }, + args: args{ + a: &User{ID: 1}, + }, + wantForbidden: true, + }, + { + name: "nonexisting task", + fields: fields{ + TaskID: 9999, + LabelID: 1, + }, + args: args{ + a: &User{ID: 1}, + }, + wantForbidden: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &LabelTask{ + ID: tt.fields.ID, + TaskID: tt.fields.TaskID, + LabelID: tt.fields.LabelID, + Created: tt.fields.Created, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + if !l.CanCreate(tt.args.a) && !tt.wantForbidden { + t.Errorf("LabelTask.CanCreate() forbidden, want %v", tt.wantForbidden) + } + err := l.Create(tt.args.a) + if (err != nil) != tt.wantErr { + t.Errorf("LabelTask.Create() error = %v, wantErr %v", err, tt.wantErr) + } + if (err != nil) && tt.wantErr && !tt.errType(err) { + t.Errorf("LabelTask.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) + } + }) + } +} + +func TestLabelTask_Delete(t *testing.T) { + type fields struct { + ID int64 + TaskID int64 + LabelID int64 + Created int64 + CRUDable web.CRUDable + Rights web.Rights + } + tests := []struct { + name string + fields fields + wantErr bool + errType func(error) bool + auth web.Auth + wantForbidden bool + }{ + { + name: "normal", + fields: fields{ + TaskID: 1, + LabelID: 1, + }, + auth: &User{ID: 1}, + }, + { + name: "delete nonexistant", + fields: fields{ + TaskID: 1, + LabelID: 1, + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + { + name: "nonexisting label", + fields: fields{ + TaskID: 1, + LabelID: 9999, + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + { + name: "nonexisting task", + fields: fields{ + TaskID: 9999, + LabelID: 1, + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + { + name: "existing, but forbidden task", + fields: fields{ + TaskID: 14, + LabelID: 1, + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &LabelTask{ + ID: tt.fields.ID, + TaskID: tt.fields.TaskID, + LabelID: tt.fields.LabelID, + Created: tt.fields.Created, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + if !l.CanDelete(tt.auth) && !tt.wantForbidden { + t.Errorf("LabelTask.CanDelete() forbidden, want %v", tt.wantForbidden) + } + err := l.Delete() + if (err != nil) != tt.wantErr { + t.Errorf("LabelTask.Delete() error = %v, wantErr %v", err, tt.wantErr) + } + if (err != nil) && tt.wantErr && !tt.errType(err) { + t.Errorf("LabelTask.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) + } + }) + } +} diff --git a/pkg/models/label_test.go b/pkg/models/label_test.go new file mode 100644 index 00000000000..39ea96e5c1b --- /dev/null +++ b/pkg/models/label_test.go @@ -0,0 +1,452 @@ +// Vikunja is a todo-list application to facilitate your life. +// Copyright 2018 Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package models + +import ( + "reflect" + "runtime" + "testing" + + "code.vikunja.io/web" +) + +func TestLabel_ReadAll(t *testing.T) { + type fields struct { + ID int64 + Title string + Description string + HexColor string + CreatedByID int64 + CreatedBy *User + Created int64 + Updated int64 + CRUDable web.CRUDable + Rights web.Rights + } + type args struct { + search string + a web.Auth + page int + } + user1 := &User{ + ID: 1, + Username: "user1", + Password: "1234", + Email: "user1@example.com", + } + tests := []struct { + name string + fields fields + args args + wantLs interface{} + wantErr bool + }{ + { + name: "normal", + args: args{ + a: &User{ID: 1}, + }, + wantLs: []*labelWithTaskID{ + { + Label: Label{ + ID: 1, + Title: "Label #1", + CreatedByID: 1, + CreatedBy: user1, + }, + }, + { + Label: Label{ + ID: 2, + Title: "Label #2", + CreatedByID: 1, + CreatedBy: user1, + }, + }, + { + TaskID: 1, + Label: Label{ + ID: 4, + Title: "Label #4 - visible via other task", + CreatedByID: 2, + CreatedBy: &User{ + ID: 2, + Username: "user2", + Password: "1234", + Email: "user2@example.com", + }, + }, + }, + }, + }, + { + name: "invalid user", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &Label{ + ID: tt.fields.ID, + Title: tt.fields.Title, + Description: tt.fields.Description, + HexColor: tt.fields.HexColor, + CreatedByID: tt.fields.CreatedByID, + CreatedBy: tt.fields.CreatedBy, + Created: tt.fields.Created, + Updated: tt.fields.Updated, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + gotLs, err := l.ReadAll(tt.args.search, tt.args.a, tt.args.page) + if (err != nil) != tt.wantErr { + t.Errorf("Label.ReadAll() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotLs, tt.wantLs) { + t.Errorf("Label.ReadAll() = %v, want %v", gotLs, tt.wantLs) + } + }) + } +} + +func TestLabel_ReadOne(t *testing.T) { + type fields struct { + ID int64 + Title string + Description string + HexColor string + CreatedByID int64 + CreatedBy *User + Created int64 + Updated int64 + CRUDable web.CRUDable + Rights web.Rights + } + user1 := &User{ + ID: 1, + Username: "user1", + Password: "1234", + Email: "user1@example.com", + } + tests := []struct { + name string + fields fields + want *Label + wantErr bool + errType func(error) bool + auth web.Auth + wantForbidden bool + }{ + { + name: "Get label #1", + fields: fields{ + ID: 1, + }, + want: &Label{ + ID: 1, + Title: "Label #1", + CreatedByID: 1, + CreatedBy: user1, + }, + auth: &User{ID: 1}, + }, + { + name: "Get nonexistant label", + fields: fields{ + ID: 9999, + }, + wantErr: true, + errType: IsErrLabelDoesNotExist, + wantForbidden: true, + auth: &User{ID: 1}, + }, + { + name: "no rights", + fields: fields{ + ID: 3, + }, + wantForbidden: true, + auth: &User{ID: 1}, + }, + { + name: "Get label #4 - other user", + fields: fields{ + ID: 4, + }, + want: &Label{ + ID: 4, + Title: "Label #4 - visible via other task", + CreatedByID: 2, + CreatedBy: &User{ + ID: 2, + Username: "user2", + Password: "1234", + Email: "user2@example.com", + }, + }, + auth: &User{ID: 1}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &Label{ + ID: tt.fields.ID, + Title: tt.fields.Title, + Description: tt.fields.Description, + HexColor: tt.fields.HexColor, + CreatedByID: tt.fields.CreatedByID, + CreatedBy: tt.fields.CreatedBy, + Created: tt.fields.Created, + Updated: tt.fields.Updated, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + + if !l.CanRead(tt.auth) && !tt.wantForbidden { + t.Errorf("Label.CanRead() forbidden, want %v", tt.wantForbidden) + } + err := l.ReadOne() + if (err != nil) != tt.wantErr { + t.Errorf("Label.ReadOne() error = %v, wantErr %v", err, tt.wantErr) + } + if (err != nil) && tt.wantErr && !tt.errType(err) { + t.Errorf("Label.ReadOne() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) + } + if !reflect.DeepEqual(l, tt.want) && !tt.wantErr && !tt.wantForbidden { + t.Errorf("Label.ReadOne() = %v, want %v", l, tt.want) + } + }) + } +} + +func TestLabel_Create(t *testing.T) { + type fields struct { + ID int64 + Title string + Description string + HexColor string + CreatedByID int64 + CreatedBy *User + Created int64 + Updated int64 + CRUDable web.CRUDable + Rights web.Rights + } + type args struct { + a web.Auth + } + tests := []struct { + name string + fields fields + args args + wantErr bool + wantForbidden bool + }{ + { + name: "normal", + fields: fields{ + Title: "Test #1", + Description: "Lorem Ipsum", + HexColor: "ffccff", + }, + args: args{ + a: &User{ID: 1}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &Label{ + ID: tt.fields.ID, + Title: tt.fields.Title, + Description: tt.fields.Description, + HexColor: tt.fields.HexColor, + CreatedByID: tt.fields.CreatedByID, + CreatedBy: tt.fields.CreatedBy, + Created: tt.fields.Created, + Updated: tt.fields.Updated, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + if !l.CanCreate(tt.args.a) && !tt.wantForbidden { + t.Errorf("Label.CanCreate() forbidden, want %v", tt.wantForbidden) + } + if err := l.Create(tt.args.a); (err != nil) != tt.wantErr { + t.Errorf("Label.Create() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestLabel_Update(t *testing.T) { + type fields struct { + ID int64 + Title string + Description string + HexColor string + CreatedByID int64 + CreatedBy *User + Created int64 + Updated int64 + CRUDable web.CRUDable + Rights web.Rights + } + tests := []struct { + name string + fields fields + wantErr bool + auth web.Auth + wantForbidden bool + }{ + { + name: "normal", + fields: fields{ + ID: 1, + Title: "new and better", + }, + auth: &User{ID: 1}, + }, + { + name: "nonexisting", + fields: fields{ + ID: 99999, + Title: "new and better", + }, + auth: &User{ID: 1}, + wantForbidden: true, + wantErr: true, + }, + { + name: "no rights", + fields: fields{ + ID: 3, + Title: "new and better", + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + { + name: "no rights other creator but access", + fields: fields{ + ID: 4, + Title: "new and better", + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &Label{ + ID: tt.fields.ID, + Title: tt.fields.Title, + Description: tt.fields.Description, + HexColor: tt.fields.HexColor, + CreatedByID: tt.fields.CreatedByID, + CreatedBy: tt.fields.CreatedBy, + Created: tt.fields.Created, + Updated: tt.fields.Updated, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + if !l.CanUpdate(tt.auth) && !tt.wantForbidden { + t.Errorf("Label.CanUpdate() forbidden, want %v", tt.wantForbidden) + } + if err := l.Update(); (err != nil) != tt.wantErr { + t.Errorf("Label.Update() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestLabel_Delete(t *testing.T) { + type fields struct { + ID int64 + Title string + Description string + HexColor string + CreatedByID int64 + CreatedBy *User + Created int64 + Updated int64 + CRUDable web.CRUDable + Rights web.Rights + } + tests := []struct { + name string + fields fields + wantErr bool + auth web.Auth + wantForbidden bool + }{ + + { + name: "normal", + fields: fields{ + ID: 1, + }, + auth: &User{ID: 1}, + }, + { + name: "nonexisting", + fields: fields{ + ID: 99999, + }, + auth: &User{ID: 1}, + wantForbidden: true, // When the label does not exist, it is forbidden. We should fix this, but for everything. + }, + { + name: "no rights", + fields: fields{ + ID: 3, + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + { + name: "no rights but visible", + fields: fields{ + ID: 4, + }, + auth: &User{ID: 1}, + wantForbidden: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &Label{ + ID: tt.fields.ID, + Title: tt.fields.Title, + Description: tt.fields.Description, + HexColor: tt.fields.HexColor, + CreatedByID: tt.fields.CreatedByID, + CreatedBy: tt.fields.CreatedBy, + Created: tt.fields.Created, + Updated: tt.fields.Updated, + CRUDable: tt.fields.CRUDable, + Rights: tt.fields.Rights, + } + if !l.CanDelete(tt.auth) && !tt.wantForbidden { + t.Errorf("Label.CanDelete() forbidden, want %v", tt.wantForbidden) + } + if err := l.Delete(); (err != nil) != tt.wantErr { + t.Errorf("Label.Delete() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/models/list_tasks.go b/pkg/models/list_tasks.go index b52319a7b89..24368d7a189 100644 --- a/pkg/models/list_tasks.go +++ b/pkg/models/list_tasks.go @@ -23,20 +23,21 @@ import ( // ListTask represents an task in a todolist type ListTask struct { - ID int64 `xorm:"int(11) autoincr not null unique pk" json:"id" param:"listtask"` - Text string `xorm:"varchar(250)" json:"text" valid:"runelength(3|250)"` - Description string `xorm:"varchar(250)" json:"description" valid:"runelength(0|250)"` - Done bool `xorm:"INDEX" json:"done"` - DueDateUnix int64 `xorm:"int(11) INDEX" json:"dueDate"` - RemindersUnix []int64 `xorm:"JSON TEXT" json:"reminderDates"` - CreatedByID int64 `xorm:"int(11)" json:"-"` // ID of the user who put that task on the list - ListID int64 `xorm:"int(11) INDEX" json:"listID" param:"list"` - RepeatAfter int64 `xorm:"int(11) INDEX" json:"repeatAfter"` - ParentTaskID int64 `xorm:"int(11) INDEX" json:"parentTaskID"` - Priority int64 `xorm:"int(11)" json:"priority"` - StartDateUnix int64 `xorm:"int(11) INDEX" json:"startDate"` - EndDateUnix int64 `xorm:"int(11) INDEX" json:"endDate"` - Assignees []*User `xorm:"-" json:"assignees"` + ID int64 `xorm:"int(11) autoincr not null unique pk" json:"id" param:"listtask"` + Text string `xorm:"varchar(250)" json:"text" valid:"runelength(3|250)"` + Description string `xorm:"varchar(250)" json:"description" valid:"runelength(0|250)"` + Done bool `xorm:"INDEX" json:"done"` + DueDateUnix int64 `xorm:"int(11) INDEX" json:"dueDate"` + RemindersUnix []int64 `xorm:"JSON TEXT" json:"reminderDates"` + CreatedByID int64 `xorm:"int(11)" json:"-"` // ID of the user who put that task on the list + ListID int64 `xorm:"int(11) INDEX" json:"listID" param:"list"` + RepeatAfter int64 `xorm:"int(11) INDEX" json:"repeatAfter"` + ParentTaskID int64 `xorm:"int(11) INDEX" json:"parentTaskID"` + Priority int64 `xorm:"int(11)" json:"priority"` + StartDateUnix int64 `xorm:"int(11) INDEX" json:"startDate"` + EndDateUnix int64 `xorm:"int(11) INDEX" json:"endDate"` + Assignees []*User `xorm:"-" json:"assignees"` + Labels []*Label `xorm:"-" json:"labels"` Sorting string `xorm:"-" json:"-" param:"sort"` // Parameter to sort by StartDateSortUnix int64 `xorm:"-" json:"-" param:"startdatefilter"` @@ -61,8 +62,8 @@ func (ListTask) TableName() string { // ListTaskAssginee represents an assignment of a user to a task type ListTaskAssginee struct { ID int64 `xorm:"int(11) autoincr not null unique pk"` - TaskID int64 `xorm:"int(11) not null"` - UserID int64 `xorm:"int(11) not null"` + TaskID int64 `xorm:"int(11) INDEX not null"` + UserID int64 `xorm:"int(11) INDEX not null"` Created int64 `xorm:"created"` } @@ -79,37 +80,24 @@ type ListTaskAssigneeWithUser struct { // GetTasksByListID gets all todotasks for a list func GetTasksByListID(listID int64) (tasks []*ListTask, err error) { - err = x.Where("list_id = ?", listID).Find(&tasks) + // make a map so we can put in a lot of other stuff more easily + taskMap := make(map[int64]*ListTask, len(tasks)) + err = x.Where("list_id = ?", listID).Find(&taskMap) if err != nil { return } - // No need to iterate over users if the list doesn't has tasks - if len(tasks) == 0 { + // No need to iterate over users and stuff if the list doesn't has tasks + if len(taskMap) == 0 { return } - // make a map so we can put in subtasks more easily - taskMap := make(map[int64]*ListTask, len(tasks)) - // Get all users & task ids and put them into the array var userIDs []int64 var taskIDs []int64 - for _, i := range tasks { + for _, i := range taskMap { taskIDs = append(taskIDs, i.ID) - found := false - for _, u := range userIDs { - if i.CreatedByID == u { - found = true - break - } - } - - if !found { - userIDs = append(userIDs, i.CreatedByID) - } - - taskMap[i.ID] = i + userIDs = append(userIDs, i.CreatedByID) } // Get all assignees @@ -124,7 +112,18 @@ func GetTasksByListID(listID int64) (tasks []*ListTask, err error) { } } - var users []User + // Get all labels for the tasks + labels, err := getLabelsByTaskIDs("", &User{}, -1, taskIDs, false) + if err != nil { + return + } + for _, l := range labels { + if l != nil { + taskMap[l.TaskID].Labels = append(taskMap[l.TaskID].Labels, &l.Label) + } + } + + users := make(map[int64]*User) err = x.In("id", userIDs).Find(&users) if err != nil { return @@ -134,12 +133,7 @@ func GetTasksByListID(listID int64) (tasks []*ListTask, err error) { for _, task := range taskMap { // Make created by user objects - for _, u := range users { - if task.CreatedByID == u.ID { - taskMap[task.ID].CreatedBy = u - break - } - } + taskMap[task.ID].CreatedBy = *users[task.CreatedByID] // Reorder all subtasks if task.ParentTaskID != 0 { @@ -154,7 +148,7 @@ func GetTasksByListID(listID int64) (tasks []*ListTask, err error) { tasks = append(tasks, t) } - // Sort the output. In Go, contents on a map are put on that map in no particular order. + // Sort the output. In Go, contents on a map are put on that map in no particular order (saved on heap). // Because of this, tasks are not sorted anymore in the output, this leads to confiusion. // To avoid all this, we need to sort the slice afterwards sort.Slice(tasks, func(i, j int) bool { @@ -174,19 +168,27 @@ func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*ListTaskAssi return } -// GetListTaskByID returns all tasks a list has -func GetListTaskByID(listTaskID int64) (listTask ListTask, err error) { - if listTaskID < 1 { - return ListTask{}, ErrListTaskDoesNotExist{listTaskID} +func getTaskByIDSimple(taskID int64) (task ListTask, err error) { + if taskID < 1 { + return ListTask{}, ErrListTaskDoesNotExist{taskID} } - exists, err := x.ID(listTaskID).Get(&listTask) + exists, err := x.ID(taskID).Get(&task) if err != nil { return ListTask{}, err } if !exists { - return ListTask{}, ErrListTaskDoesNotExist{listTaskID} + return ListTask{}, ErrListTaskDoesNotExist{taskID} + } + return +} + +// GetListTaskByID returns all tasks a list has +func GetListTaskByID(listTaskID int64) (listTask ListTask, err error) { + listTask, err = getTaskByIDSimple(listTaskID) + if err != nil { + return } u, err := GetUserByID(listTask.CreatedByID) diff --git a/pkg/models/list_tasks_rights.go b/pkg/models/list_tasks_rights.go index 2b6f9d972d1..c9b60617e52 100644 --- a/pkg/models/list_tasks_rights.go +++ b/pkg/models/list_tasks_rights.go @@ -43,15 +43,19 @@ func (t *ListTask) CanUpdate(a web.Auth) bool { doer := getUserForRights(a) // Get the task - lI, err := GetListTaskByID(t.ID) + lI, err := getTaskByIDSimple(t.ID) if err != nil { - log.Log.Error("Error occurred during CanDelete for ListTask: %s", err) + log.Log.Error("Error occurred during CanUpdate (getTaskByIDSimple) for ListTask: %s", err) return false } // A user can update an task if he has write acces to its list l := &List{ID: lI.ListID} - l.ReadOne() + err = l.GetSimpleByID() + if err != nil { + log.Log.Error("Error occurred during CanUpdate (ReadOne) for ListTask: %s", err) + return false + } return l.CanWrite(doer) } @@ -64,3 +68,10 @@ func (t *ListTask) CanCreate(a web.Auth) bool { l.ReadOne() return l.CanWrite(doer) } + +// CanRead determines if a user can read a task +func (t *ListTask) CanRead(a web.Auth) bool { + // A user can read a task if it has access to the list + list := &List{ID: t.ListID} + return list.CanRead(a) +} diff --git a/pkg/models/models.go b/pkg/models/models.go index 5b95173c41b..2ec7f570ed3 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -69,6 +69,8 @@ func init() { new(ListUser), new(NamespaceUser), new(ListTaskAssginee), + new(Label), + new(LabelTask), ) tablesWithPointer = append(tables, @@ -83,6 +85,8 @@ func init() { &ListUser{}, &NamespaceUser{}, &ListTaskAssginee{}, + &Label{}, + &LabelTask{}, ) } diff --git a/pkg/routes/routes.go b/pkg/routes/routes.go index 89b51d9abd4..e3daa2f1478 100644 --- a/pkg/routes/routes.go +++ b/pkg/routes/routes.go @@ -229,6 +229,26 @@ func RegisterRoutes(e *echo.Echo) { } a.POST("/tasks/bulk", bulkTaskHandler.UpdateWeb) + labelTaskHandler := &handler.WebHandler{ + EmptyStruct: func() handler.CObject { + return &models.LabelTask{} + }, + } + a.PUT("/tasks/:listtask/labels", labelTaskHandler.CreateWeb) + a.DELETE("/tasks/:listtask/labels/:label", labelTaskHandler.DeleteWeb) + a.GET("/tasks/:listtask/labels", labelTaskHandler.ReadAllWeb) + + labelHandler := &handler.WebHandler{ + EmptyStruct: func() handler.CObject { + return &models.Label{} + }, + } + a.GET("/labels", labelHandler.ReadAllWeb) + a.GET("/labels/:label", labelHandler.ReadOneWeb) + a.PUT("/labels", labelHandler.CreateWeb) + a.DELETE("/labels/:label", labelHandler.DeleteWeb) + a.POST("/labels/:label", labelHandler.UpdateWeb) + listTeamHandler := &handler.WebHandler{ EmptyStruct: func() handler.CObject { return &models.TeamList{} diff --git a/vendor/honnef.co/go/tools/callgraph/callgraph.go b/vendor/honnef.co/go/tools/callgraph/callgraph.go new file mode 100644 index 00000000000..d93a20a3a16 --- /dev/null +++ b/vendor/honnef.co/go/tools/callgraph/callgraph.go @@ -0,0 +1,129 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + +Package callgraph defines the call graph and various algorithms +and utilities to operate on it. + +A call graph is a labelled directed graph whose nodes represent +functions and whose edge labels represent syntactic function call +sites. The presence of a labelled edge (caller, site, callee) +indicates that caller may call callee at the specified call site. + +A call graph is a multigraph: it may contain multiple edges (caller, +*, callee) connecting the same pair of nodes, so long as the edges +differ by label; this occurs when one function calls another function +from multiple call sites. Also, it may contain multiple edges +(caller, site, *) that differ only by callee; this indicates a +polymorphic call. + +A SOUND call graph is one that overapproximates the dynamic calling +behaviors of the program in all possible executions. One call graph +is more PRECISE than another if it is a smaller overapproximation of +the dynamic behavior. + +All call graphs have a synthetic root node which is responsible for +calling main() and init(). + +Calls to built-in functions (e.g. panic, println) are not represented +in the call graph; they are treated like built-in operators of the +language. + +*/ +package callgraph // import "honnef.co/go/tools/callgraph" + +// TODO(adonovan): add a function to eliminate wrappers from the +// callgraph, preserving topology. +// More generally, we could eliminate "uninteresting" nodes such as +// nodes from packages we don't care about. + +import ( + "fmt" + "go/token" + + "honnef.co/go/tools/ssa" +) + +// A Graph represents a call graph. +// +// A graph may contain nodes that are not reachable from the root. +// If the call graph is sound, such nodes indicate unreachable +// functions. +// +type Graph struct { + Root *Node // the distinguished root node + Nodes map[*ssa.Function]*Node // all nodes by function +} + +// New returns a new Graph with the specified root node. +func New(root *ssa.Function) *Graph { + g := &Graph{Nodes: make(map[*ssa.Function]*Node)} + g.Root = g.CreateNode(root) + return g +} + +// CreateNode returns the Node for fn, creating it if not present. +func (g *Graph) CreateNode(fn *ssa.Function) *Node { + n, ok := g.Nodes[fn] + if !ok { + n = &Node{Func: fn, ID: len(g.Nodes)} + g.Nodes[fn] = n + } + return n +} + +// A Node represents a node in a call graph. +type Node struct { + Func *ssa.Function // the function this node represents + ID int // 0-based sequence number + In []*Edge // unordered set of incoming call edges (n.In[*].Callee == n) + Out []*Edge // unordered set of outgoing call edges (n.Out[*].Caller == n) +} + +func (n *Node) String() string { + return fmt.Sprintf("n%d:%s", n.ID, n.Func) +} + +// A Edge represents an edge in the call graph. +// +// Site is nil for edges originating in synthetic or intrinsic +// functions, e.g. reflect.Call or the root of the call graph. +type Edge struct { + Caller *Node + Site ssa.CallInstruction + Callee *Node +} + +func (e Edge) String() string { + return fmt.Sprintf("%s --> %s", e.Caller, e.Callee) +} + +func (e Edge) Description() string { + var prefix string + switch e.Site.(type) { + case nil: + return "synthetic call" + case *ssa.Go: + prefix = "concurrent " + case *ssa.Defer: + prefix = "deferred " + } + return prefix + e.Site.Common().Description() +} + +func (e Edge) Pos() token.Pos { + if e.Site == nil { + return token.NoPos + } + return e.Site.Pos() +} + +// AddEdge adds the edge (caller, site, callee) to the call graph. +// Elimination of duplicate edges is the caller's responsibility. +func AddEdge(caller *Node, site ssa.CallInstruction, callee *Node) { + e := &Edge{caller, site, callee} + callee.In = append(callee.In, e) + caller.Out = append(caller.Out, e) +} diff --git a/vendor/honnef.co/go/tools/callgraph/static/static.go b/vendor/honnef.co/go/tools/callgraph/static/static.go new file mode 100644 index 00000000000..5444e841134 --- /dev/null +++ b/vendor/honnef.co/go/tools/callgraph/static/static.go @@ -0,0 +1,35 @@ +// Package static computes the call graph of a Go program containing +// only static call edges. +package static // import "honnef.co/go/tools/callgraph/static" + +import ( + "honnef.co/go/tools/callgraph" + "honnef.co/go/tools/ssa" + "honnef.co/go/tools/ssa/ssautil" +) + +// CallGraph computes the call graph of the specified program +// considering only static calls. +// +func CallGraph(prog *ssa.Program) *callgraph.Graph { + cg := callgraph.New(nil) // TODO(adonovan) eliminate concept of rooted callgraph + + // TODO(adonovan): opt: use only a single pass over the ssa.Program. + // TODO(adonovan): opt: this is slower than RTA (perhaps because + // the lower precision means so many edges are allocated)! + for f := range ssautil.AllFunctions(prog) { + fnode := cg.CreateNode(f) + for _, b := range f.Blocks { + for _, instr := range b.Instrs { + if site, ok := instr.(ssa.CallInstruction); ok { + if g := site.Common().StaticCallee(); g != nil { + gnode := cg.CreateNode(g) + callgraph.AddEdge(fnode, site, gnode) + } + } + } + } + } + + return cg +} diff --git a/vendor/honnef.co/go/tools/callgraph/util.go b/vendor/honnef.co/go/tools/callgraph/util.go new file mode 100644 index 00000000000..7aeda964159 --- /dev/null +++ b/vendor/honnef.co/go/tools/callgraph/util.go @@ -0,0 +1,181 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package callgraph + +import "honnef.co/go/tools/ssa" + +// This file provides various utilities over call graphs, such as +// visitation and path search. + +// CalleesOf returns a new set containing all direct callees of the +// caller node. +// +func CalleesOf(caller *Node) map[*Node]bool { + callees := make(map[*Node]bool) + for _, e := range caller.Out { + callees[e.Callee] = true + } + return callees +} + +// GraphVisitEdges visits all the edges in graph g in depth-first order. +// The edge function is called for each edge in postorder. If it +// returns non-nil, visitation stops and GraphVisitEdges returns that +// value. +// +func GraphVisitEdges(g *Graph, edge func(*Edge) error) error { + seen := make(map[*Node]bool) + var visit func(n *Node) error + visit = func(n *Node) error { + if !seen[n] { + seen[n] = true + for _, e := range n.Out { + if err := visit(e.Callee); err != nil { + return err + } + if err := edge(e); err != nil { + return err + } + } + } + return nil + } + for _, n := range g.Nodes { + if err := visit(n); err != nil { + return err + } + } + return nil +} + +// PathSearch finds an arbitrary path starting at node start and +// ending at some node for which isEnd() returns true. On success, +// PathSearch returns the path as an ordered list of edges; on +// failure, it returns nil. +// +func PathSearch(start *Node, isEnd func(*Node) bool) []*Edge { + stack := make([]*Edge, 0, 32) + seen := make(map[*Node]bool) + var search func(n *Node) []*Edge + search = func(n *Node) []*Edge { + if !seen[n] { + seen[n] = true + if isEnd(n) { + return stack + } + for _, e := range n.Out { + stack = append(stack, e) // push + if found := search(e.Callee); found != nil { + return found + } + stack = stack[:len(stack)-1] // pop + } + } + return nil + } + return search(start) +} + +// DeleteSyntheticNodes removes from call graph g all nodes for +// synthetic functions (except g.Root and package initializers), +// preserving the topology. In effect, calls to synthetic wrappers +// are "inlined". +// +func (g *Graph) DeleteSyntheticNodes() { + // Measurements on the standard library and go.tools show that + // resulting graph has ~15% fewer nodes and 4-8% fewer edges + // than the input. + // + // Inlining a wrapper of in-degree m, out-degree n adds m*n + // and removes m+n edges. Since most wrappers are monomorphic + // (n=1) this results in a slight reduction. Polymorphic + // wrappers (n>1), e.g. from embedding an interface value + // inside a struct to satisfy some interface, cause an + // increase in the graph, but they seem to be uncommon. + + // Hash all existing edges to avoid creating duplicates. + edges := make(map[Edge]bool) + for _, cgn := range g.Nodes { + for _, e := range cgn.Out { + edges[*e] = true + } + } + for fn, cgn := range g.Nodes { + if cgn == g.Root || fn.Synthetic == "" || isInit(cgn.Func) { + continue // keep + } + for _, eIn := range cgn.In { + for _, eOut := range cgn.Out { + newEdge := Edge{eIn.Caller, eIn.Site, eOut.Callee} + if edges[newEdge] { + continue // don't add duplicate + } + AddEdge(eIn.Caller, eIn.Site, eOut.Callee) + edges[newEdge] = true + } + } + g.DeleteNode(cgn) + } +} + +func isInit(fn *ssa.Function) bool { + return fn.Pkg != nil && fn.Pkg.Func("init") == fn +} + +// DeleteNode removes node n and its edges from the graph g. +// (NB: not efficient for batch deletion.) +func (g *Graph) DeleteNode(n *Node) { + n.deleteIns() + n.deleteOuts() + delete(g.Nodes, n.Func) +} + +// deleteIns deletes all incoming edges to n. +func (n *Node) deleteIns() { + for _, e := range n.In { + removeOutEdge(e) + } + n.In = nil +} + +// deleteOuts deletes all outgoing edges from n. +func (n *Node) deleteOuts() { + for _, e := range n.Out { + removeInEdge(e) + } + n.Out = nil +} + +// removeOutEdge removes edge.Caller's outgoing edge 'edge'. +func removeOutEdge(edge *Edge) { + caller := edge.Caller + n := len(caller.Out) + for i, e := range caller.Out { + if e == edge { + // Replace it with the final element and shrink the slice. + caller.Out[i] = caller.Out[n-1] + caller.Out[n-1] = nil // aid GC + caller.Out = caller.Out[:n-1] + return + } + } + panic("edge not found: " + edge.String()) +} + +// removeInEdge removes edge.Callee's incoming edge 'edge'. +func removeInEdge(edge *Edge) { + caller := edge.Callee + n := len(caller.In) + for i, e := range caller.In { + if e == edge { + // Replace it with the final element and shrink the slice. + caller.In[i] = caller.In[n-1] + caller.In[n-1] = nil // aid GC + caller.In = caller.In[:n-1] + return + } + } + panic("edge not found: " + edge.String()) +} diff --git a/vendor/honnef.co/go/tools/cmd/staticcheck/README.md b/vendor/honnef.co/go/tools/cmd/staticcheck/README.md new file mode 100644 index 00000000000..7fb4dabf64e --- /dev/null +++ b/vendor/honnef.co/go/tools/cmd/staticcheck/README.md @@ -0,0 +1,16 @@ +# staticcheck + +_staticcheck_ is `go vet` on steroids, applying a ton of static analysis +checks you might be used to from tools like ReSharper for C#. + +## Installation + +Staticcheck requires Go 1.6 or later. + + go get honnef.co/go/tools/cmd/staticcheck + +## Documentation + +Detailed documentation can be found on +[staticcheck.io](https://staticcheck.io/docs/staticcheck). + diff --git a/vendor/honnef.co/go/tools/cmd/staticcheck/staticcheck.go b/vendor/honnef.co/go/tools/cmd/staticcheck/staticcheck.go new file mode 100644 index 00000000000..5e6d6f9cc78 --- /dev/null +++ b/vendor/honnef.co/go/tools/cmd/staticcheck/staticcheck.go @@ -0,0 +1,23 @@ +// staticcheck detects a myriad of bugs and inefficiencies in your +// code. +package main // import "honnef.co/go/tools/cmd/staticcheck" + +import ( + "os" + + "honnef.co/go/tools/lint/lintutil" + "honnef.co/go/tools/staticcheck" +) + +func main() { + fs := lintutil.FlagSet("staticcheck") + gen := fs.Bool("generated", false, "Check generated code") + fs.Parse(os.Args[1:]) + c := staticcheck.NewChecker() + c.CheckGenerated = *gen + cfg := lintutil.CheckerConfig{ + Checker: c, + ExitNonZero: true, + } + lintutil.ProcessFlagSet([]lintutil.CheckerConfig{cfg}, fs) +} diff --git a/vendor/honnef.co/go/tools/deprecated/stdlib.go b/vendor/honnef.co/go/tools/deprecated/stdlib.go new file mode 100644 index 00000000000..b6b217c3e37 --- /dev/null +++ b/vendor/honnef.co/go/tools/deprecated/stdlib.go @@ -0,0 +1,54 @@ +package deprecated + +type Deprecation struct { + DeprecatedSince int + AlternativeAvailableSince int +} + +var Stdlib = map[string]Deprecation{ + "image/jpeg.Reader": {4, 0}, + // FIXME(dh): AllowBinary isn't being detected as deprecated + // because the comment has a newline right after "Deprecated:" + "go/build.AllowBinary": {7, 7}, + "(archive/zip.FileHeader).CompressedSize": {1, 1}, + "(archive/zip.FileHeader).UncompressedSize": {1, 1}, + "(go/doc.Package).Bugs": {1, 1}, + "os.SEEK_SET": {7, 7}, + "os.SEEK_CUR": {7, 7}, + "os.SEEK_END": {7, 7}, + "(net.Dialer).Cancel": {7, 7}, + "runtime.CPUProfile": {9, 0}, + "compress/flate.ReadError": {6, 6}, + "compress/flate.WriteError": {6, 6}, + "path/filepath.HasPrefix": {0, 0}, + "(net/http.Transport).Dial": {7, 7}, + "(*net/http.Transport).CancelRequest": {6, 5}, + "net/http.ErrWriteAfterFlush": {7, 0}, + "net/http.ErrHeaderTooLong": {8, 0}, + "net/http.ErrShortBody": {8, 0}, + "net/http.ErrMissingContentLength": {8, 0}, + "net/http/httputil.ErrPersistEOF": {0, 0}, + "net/http/httputil.ErrClosed": {0, 0}, + "net/http/httputil.ErrPipeline": {0, 0}, + "net/http/httputil.ServerConn": {0, 0}, + "net/http/httputil.NewServerConn": {0, 0}, + "net/http/httputil.ClientConn": {0, 0}, + "net/http/httputil.NewClientConn": {0, 0}, + "net/http/httputil.NewProxyClientConn": {0, 0}, + "(net/http.Request).Cancel": {7, 7}, + "(text/template/parse.PipeNode).Line": {1, 1}, + "(text/template/parse.ActionNode).Line": {1, 1}, + "(text/template/parse.BranchNode).Line": {1, 1}, + "(text/template/parse.TemplateNode).Line": {1, 1}, + "database/sql/driver.ColumnConverter": {9, 9}, + "database/sql/driver.Execer": {8, 8}, + "database/sql/driver.Queryer": {8, 8}, + "(database/sql/driver.Conn).Begin": {8, 8}, + "(database/sql/driver.Stmt).Exec": {8, 8}, + "(database/sql/driver.Stmt).Query": {8, 8}, + "syscall.StringByteSlice": {1, 1}, + "syscall.StringBytePtr": {1, 1}, + "syscall.StringSlicePtr": {1, 1}, + "syscall.StringToUTF16": {1, 1}, + "syscall.StringToUTF16Ptr": {1, 1}, +} diff --git a/vendor/honnef.co/go/tools/functions/concrete.go b/vendor/honnef.co/go/tools/functions/concrete.go new file mode 100644 index 00000000000..932acd03eda --- /dev/null +++ b/vendor/honnef.co/go/tools/functions/concrete.go @@ -0,0 +1,56 @@ +package functions + +import ( + "go/token" + "go/types" + + "honnef.co/go/tools/ssa" +) + +func concreteReturnTypes(fn *ssa.Function) []*types.Tuple { + res := fn.Signature.Results() + if res == nil { + return nil + } + ifaces := make([]bool, res.Len()) + any := false + for i := 0; i < res.Len(); i++ { + _, ifaces[i] = res.At(i).Type().Underlying().(*types.Interface) + any = any || ifaces[i] + } + if !any { + return []*types.Tuple{res} + } + var out []*types.Tuple + for _, block := range fn.Blocks { + if len(block.Instrs) == 0 { + continue + } + ret, ok := block.Instrs[len(block.Instrs)-1].(*ssa.Return) + if !ok { + continue + } + vars := make([]*types.Var, res.Len()) + for i, v := range ret.Results { + var typ types.Type + if !ifaces[i] { + typ = res.At(i).Type() + } else if mi, ok := v.(*ssa.MakeInterface); ok { + // TODO(dh): if mi.X is a function call that returns + // an interface, call concreteReturnTypes on that + // function (or, really, go through Descriptions, + // avoid infinite recursion etc, just like nil error + // detection) + + // TODO(dh): support Phi nodes + typ = mi.X.Type() + } else { + typ = res.At(i).Type() + } + vars[i] = types.NewParam(token.NoPos, nil, "", typ) + } + out = append(out, types.NewTuple(vars...)) + } + // TODO(dh): deduplicate out + return out +} diff --git a/vendor/honnef.co/go/tools/functions/functions.go b/vendor/honnef.co/go/tools/functions/functions.go new file mode 100644 index 00000000000..83940412975 --- /dev/null +++ b/vendor/honnef.co/go/tools/functions/functions.go @@ -0,0 +1,150 @@ +package functions + +import ( + "go/types" + "sync" + + "honnef.co/go/tools/callgraph" + "honnef.co/go/tools/callgraph/static" + "honnef.co/go/tools/ssa" + "honnef.co/go/tools/staticcheck/vrp" +) + +var stdlibDescs = map[string]Description{ + "errors.New": {Pure: true}, + + "fmt.Errorf": {Pure: true}, + "fmt.Sprintf": {Pure: true}, + "fmt.Sprint": {Pure: true}, + + "sort.Reverse": {Pure: true}, + + "strings.Map": {Pure: true}, + "strings.Repeat": {Pure: true}, + "strings.Replace": {Pure: true}, + "strings.Title": {Pure: true}, + "strings.ToLower": {Pure: true}, + "strings.ToLowerSpecial": {Pure: true}, + "strings.ToTitle": {Pure: true}, + "strings.ToTitleSpecial": {Pure: true}, + "strings.ToUpper": {Pure: true}, + "strings.ToUpperSpecial": {Pure: true}, + "strings.Trim": {Pure: true}, + "strings.TrimFunc": {Pure: true}, + "strings.TrimLeft": {Pure: true}, + "strings.TrimLeftFunc": {Pure: true}, + "strings.TrimPrefix": {Pure: true}, + "strings.TrimRight": {Pure: true}, + "strings.TrimRightFunc": {Pure: true}, + "strings.TrimSpace": {Pure: true}, + "strings.TrimSuffix": {Pure: true}, + + "(*net/http.Request).WithContext": {Pure: true}, + + "math/rand.Read": {NilError: true}, + "(*math/rand.Rand).Read": {NilError: true}, +} + +type Description struct { + // The function is known to be pure + Pure bool + // The function is known to be a stub + Stub bool + // The function is known to never return (panics notwithstanding) + Infinite bool + // Variable ranges + Ranges vrp.Ranges + Loops []Loop + // Function returns an error as its last argument, but it is + // always nil + NilError bool + ConcreteReturnTypes []*types.Tuple +} + +type descriptionEntry struct { + ready chan struct{} + result Description +} + +type Descriptions struct { + CallGraph *callgraph.Graph + mu sync.Mutex + cache map[*ssa.Function]*descriptionEntry +} + +func NewDescriptions(prog *ssa.Program) *Descriptions { + return &Descriptions{ + CallGraph: static.CallGraph(prog), + cache: map[*ssa.Function]*descriptionEntry{}, + } +} + +func (d *Descriptions) Get(fn *ssa.Function) Description { + d.mu.Lock() + fd := d.cache[fn] + if fd == nil { + fd = &descriptionEntry{ + ready: make(chan struct{}), + } + d.cache[fn] = fd + d.mu.Unlock() + + { + fd.result = stdlibDescs[fn.RelString(nil)] + fd.result.Pure = fd.result.Pure || d.IsPure(fn) + fd.result.Stub = fd.result.Stub || d.IsStub(fn) + fd.result.Infinite = fd.result.Infinite || !terminates(fn) + fd.result.Ranges = vrp.BuildGraph(fn).Solve() + fd.result.Loops = findLoops(fn) + fd.result.NilError = fd.result.NilError || IsNilError(fn) + fd.result.ConcreteReturnTypes = concreteReturnTypes(fn) + } + + close(fd.ready) + } else { + d.mu.Unlock() + <-fd.ready + } + return fd.result +} + +func IsNilError(fn *ssa.Function) bool { + // TODO(dh): This is very simplistic, as we only look for constant + // nil returns. A more advanced approach would work transitively. + // An even more advanced approach would be context-aware and + // determine nil errors based on inputs (e.g. io.WriteString to a + // bytes.Buffer will always return nil, but an io.WriteString to + // an os.File might not). Similarly, an os.File opened for reading + // won't error on Close, but other files will. + res := fn.Signature.Results() + if res.Len() == 0 { + return false + } + last := res.At(res.Len() - 1) + if types.TypeString(last.Type(), nil) != "error" { + return false + } + + if fn.Blocks == nil { + return false + } + for _, block := range fn.Blocks { + if len(block.Instrs) == 0 { + continue + } + ins := block.Instrs[len(block.Instrs)-1] + ret, ok := ins.(*ssa.Return) + if !ok { + continue + } + v := ret.Results[len(ret.Results)-1] + c, ok := v.(*ssa.Const) + if !ok { + return false + } + if !c.IsNil() { + return false + } + } + return true +} diff --git a/vendor/honnef.co/go/tools/functions/loops.go b/vendor/honnef.co/go/tools/functions/loops.go new file mode 100644 index 00000000000..63011cf3ef6 --- /dev/null +++ b/vendor/honnef.co/go/tools/functions/loops.go @@ -0,0 +1,50 @@ +package functions + +import "honnef.co/go/tools/ssa" + +type Loop map[*ssa.BasicBlock]bool + +func findLoops(fn *ssa.Function) []Loop { + if fn.Blocks == nil { + return nil + } + tree := fn.DomPreorder() + var sets []Loop + for _, h := range tree { + for _, n := range h.Preds { + if !h.Dominates(n) { + continue + } + // n is a back-edge to h + // h is the loop header + if n == h { + sets = append(sets, Loop{n: true}) + continue + } + set := Loop{h: true, n: true} + for _, b := range allPredsBut(n, h, nil) { + set[b] = true + } + sets = append(sets, set) + } + } + return sets +} + +func allPredsBut(b, but *ssa.BasicBlock, list []*ssa.BasicBlock) []*ssa.BasicBlock { +outer: + for _, pred := range b.Preds { + if pred == but { + continue + } + for _, p := range list { + // TODO improve big-o complexity of this function + if pred == p { + continue outer + } + } + list = append(list, pred) + list = allPredsBut(pred, but, list) + } + return list +} diff --git a/vendor/honnef.co/go/tools/functions/pure.go b/vendor/honnef.co/go/tools/functions/pure.go new file mode 100644 index 00000000000..7028eb8c649 --- /dev/null +++ b/vendor/honnef.co/go/tools/functions/pure.go @@ -0,0 +1,123 @@ +package functions + +import ( + "go/token" + "go/types" + + "honnef.co/go/tools/callgraph" + "honnef.co/go/tools/lint/lintdsl" + "honnef.co/go/tools/ssa" +) + +// IsStub reports whether a function is a stub. A function is +// considered a stub if it has no instructions or exactly one +// instruction, which must be either returning only constant values or +// a panic. +func (d *Descriptions) IsStub(fn *ssa.Function) bool { + if len(fn.Blocks) == 0 { + return true + } + if len(fn.Blocks) > 1 { + return false + } + instrs := lintdsl.FilterDebug(fn.Blocks[0].Instrs) + if len(instrs) != 1 { + return false + } + + switch instrs[0].(type) { + case *ssa.Return: + // Since this is the only instruction, the return value must + // be a constant. We consider all constants as stubs, not just + // the zero value. This does not, unfortunately, cover zero + // initialised structs, as these cause additional + // instructions. + return true + case *ssa.Panic: + return true + default: + return false + } +} + +func (d *Descriptions) IsPure(fn *ssa.Function) bool { + if fn.Signature.Results().Len() == 0 { + // A function with no return values is empty or is doing some + // work we cannot see (for example because of build tags); + // don't consider it pure. + return false + } + + for _, param := range fn.Params { + if _, ok := param.Type().Underlying().(*types.Basic); !ok { + return false + } + } + + if fn.Blocks == nil { + return false + } + checkCall := func(common *ssa.CallCommon) bool { + if common.IsInvoke() { + return false + } + builtin, ok := common.Value.(*ssa.Builtin) + if !ok { + if common.StaticCallee() != fn { + if common.StaticCallee() == nil { + return false + } + // TODO(dh): ideally, IsPure wouldn't be responsible + // for avoiding infinite recursion, but + // FunctionDescriptions would be. + node := d.CallGraph.CreateNode(common.StaticCallee()) + if callgraph.PathSearch(node, func(other *callgraph.Node) bool { + return other.Func == fn + }) != nil { + return false + } + if !d.Get(common.StaticCallee()).Pure { + return false + } + } + } else { + switch builtin.Name() { + case "len", "cap", "make", "new": + default: + return false + } + } + return true + } + for _, b := range fn.Blocks { + for _, ins := range b.Instrs { + switch ins := ins.(type) { + case *ssa.Call: + if !checkCall(ins.Common()) { + return false + } + case *ssa.Defer: + if !checkCall(&ins.Call) { + return false + } + case *ssa.Select: + return false + case *ssa.Send: + return false + case *ssa.Go: + return false + case *ssa.Panic: + return false + case *ssa.Store: + return false + case *ssa.FieldAddr: + return false + case *ssa.UnOp: + if ins.Op == token.MUL || ins.Op == token.AND { + return false + } + } + } + } + return true +} diff --git a/vendor/honnef.co/go/tools/functions/terminates.go b/vendor/honnef.co/go/tools/functions/terminates.go new file mode 100644 index 00000000000..65f9e16dc99 --- /dev/null +++ b/vendor/honnef.co/go/tools/functions/terminates.go @@ -0,0 +1,24 @@ +package functions + +import "honnef.co/go/tools/ssa" + +// terminates reports whether fn is supposed to return, that is if it +// has at least one theoretic path that returns from the function. +// Explicit panics do not count as terminating. +func terminates(fn *ssa.Function) bool { + if fn.Blocks == nil { + // assuming that a function terminates is the conservative + // choice + return true + } + + for _, block := range fn.Blocks { + if len(block.Instrs) == 0 { + continue + } + if _, ok := block.Instrs[len(block.Instrs)-1].(*ssa.Return); ok { + return true + } + } + return false +} diff --git a/vendor/honnef.co/go/tools/staticcheck/CONTRIBUTING.md b/vendor/honnef.co/go/tools/staticcheck/CONTRIBUTING.md new file mode 100644 index 00000000000..b12c7afc748 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/CONTRIBUTING.md @@ -0,0 +1,15 @@ +# Contributing to staticcheck + +## Before filing an issue: + +### Are you having trouble building staticcheck? + +Check you have the latest version of its dependencies. Run +``` +go get -u honnef.co/go/tools/staticcheck +``` +If you still have problems, consider searching for existing issues before filing a new issue. + +## Before sending a pull request: + +Have you understood the purpose of staticcheck? Make sure to carefully read `README`. diff --git a/vendor/honnef.co/go/tools/staticcheck/buildtag.go b/vendor/honnef.co/go/tools/staticcheck/buildtag.go new file mode 100644 index 00000000000..888d3e9dc05 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/buildtag.go @@ -0,0 +1,21 @@ +package staticcheck + +import ( + "go/ast" + "strings" + + . "honnef.co/go/tools/lint/lintdsl" +) + +func buildTags(f *ast.File) [][]string { + var out [][]string + for _, line := range strings.Split(Preamble(f), "\n") { + if !strings.HasPrefix(line, "+build ") { + continue + } + line = strings.TrimSpace(strings.TrimPrefix(line, "+build ")) + fields := strings.Fields(line) + out = append(out, fields) + } + return out +} diff --git a/vendor/honnef.co/go/tools/staticcheck/lint.go b/vendor/honnef.co/go/tools/staticcheck/lint.go new file mode 100644 index 00000000000..7d03ca7150c --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/lint.go @@ -0,0 +1,2790 @@ +// Package staticcheck contains a linter for Go source code. +package staticcheck // import "honnef.co/go/tools/staticcheck" + +import ( + "fmt" + "go/ast" + "go/constant" + "go/token" + "go/types" + htmltemplate "html/template" + "net/http" + "regexp" + "regexp/syntax" + "sort" + "strconv" + "strings" + "sync" + texttemplate "text/template" + + "honnef.co/go/tools/deprecated" + "honnef.co/go/tools/functions" + "honnef.co/go/tools/internal/sharedcheck" + "honnef.co/go/tools/lint" + . "honnef.co/go/tools/lint/lintdsl" + "honnef.co/go/tools/ssa" + "honnef.co/go/tools/staticcheck/vrp" + + "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/loader" +) + +func validRegexp(call *Call) { + arg := call.Args[0] + err := ValidateRegexp(arg.Value) + if err != nil { + arg.Invalid(err.Error()) + } +} + +type runeSlice []rune + +func (rs runeSlice) Len() int { return len(rs) } +func (rs runeSlice) Less(i int, j int) bool { return rs[i] < rs[j] } +func (rs runeSlice) Swap(i int, j int) { rs[i], rs[j] = rs[j], rs[i] } + +func utf8Cutset(call *Call) { + arg := call.Args[1] + if InvalidUTF8(arg.Value) { + arg.Invalid(MsgInvalidUTF8) + } +} + +func uniqueCutset(call *Call) { + arg := call.Args[1] + if !UniqueStringCutset(arg.Value) { + arg.Invalid(MsgNonUniqueCutset) + } +} + +func unmarshalPointer(name string, arg int) CallCheck { + return func(call *Call) { + if !Pointer(call.Args[arg].Value) { + call.Args[arg].Invalid(fmt.Sprintf("%s expects to unmarshal into a pointer, but the provided value is not a pointer", name)) + } + } +} + +func pointlessIntMath(call *Call) { + if ConvertedFromInt(call.Args[0].Value) { + call.Invalid(fmt.Sprintf("calling %s on a converted integer is pointless", CallName(call.Instr.Common()))) + } +} + +func checkValidHostPort(arg int) CallCheck { + return func(call *Call) { + if !ValidHostPort(call.Args[arg].Value) { + call.Args[arg].Invalid(MsgInvalidHostPort) + } + } +} + +var ( + checkRegexpRules = map[string]CallCheck{ + "regexp.MustCompile": validRegexp, + "regexp.Compile": validRegexp, + "regexp.Match": validRegexp, + "regexp.MatchReader": validRegexp, + "regexp.MatchString": validRegexp, + } + + checkTimeParseRules = map[string]CallCheck{ + "time.Parse": func(call *Call) { + arg := call.Args[0] + err := ValidateTimeLayout(arg.Value) + if err != nil { + arg.Invalid(err.Error()) + } + }, + } + + checkEncodingBinaryRules = map[string]CallCheck{ + "encoding/binary.Write": func(call *Call) { + arg := call.Args[2] + if !CanBinaryMarshal(call.Job, arg.Value) { + arg.Invalid(fmt.Sprintf("value of type %s cannot be used with binary.Write", arg.Value.Value.Type())) + } + }, + } + + checkURLsRules = map[string]CallCheck{ + "net/url.Parse": func(call *Call) { + arg := call.Args[0] + err := ValidateURL(arg.Value) + if err != nil { + arg.Invalid(err.Error()) + } + }, + } + + checkSyncPoolValueRules = map[string]CallCheck{ + "(*sync.Pool).Put": func(call *Call) { + arg := call.Args[0] + typ := arg.Value.Value.Type() + if !IsPointerLike(typ) { + arg.Invalid("argument should be pointer-like to avoid allocations") + } + }, + } + + checkRegexpFindAllRules = map[string]CallCheck{ + "(*regexp.Regexp).FindAll": RepeatZeroTimes("a FindAll method", 1), + "(*regexp.Regexp).FindAllIndex": RepeatZeroTimes("a FindAll method", 1), + "(*regexp.Regexp).FindAllString": RepeatZeroTimes("a FindAll method", 1), + "(*regexp.Regexp).FindAllStringIndex": RepeatZeroTimes("a FindAll method", 1), + "(*regexp.Regexp).FindAllStringSubmatch": RepeatZeroTimes("a FindAll method", 1), + "(*regexp.Regexp).FindAllStringSubmatchIndex": RepeatZeroTimes("a FindAll method", 1), + "(*regexp.Regexp).FindAllSubmatch": RepeatZeroTimes("a FindAll method", 1), + "(*regexp.Regexp).FindAllSubmatchIndex": RepeatZeroTimes("a FindAll method", 1), + } + + checkUTF8CutsetRules = map[string]CallCheck{ + "strings.IndexAny": utf8Cutset, + "strings.LastIndexAny": utf8Cutset, + "strings.ContainsAny": utf8Cutset, + "strings.Trim": utf8Cutset, + "strings.TrimLeft": utf8Cutset, + "strings.TrimRight": utf8Cutset, + } + + checkUniqueCutsetRules = map[string]CallCheck{ + "strings.Trim": uniqueCutset, + "strings.TrimLeft": uniqueCutset, + "strings.TrimRight": uniqueCutset, + } + + checkUnmarshalPointerRules = map[string]CallCheck{ + "encoding/xml.Unmarshal": unmarshalPointer("xml.Unmarshal", 1), + "(*encoding/xml.Decoder).Decode": unmarshalPointer("Decode", 0), + "(*encoding/xml.Decoder).DecodeElement": unmarshalPointer("DecodeElement", 0), + "encoding/json.Unmarshal": unmarshalPointer("json.Unmarshal", 1), + "(*encoding/json.Decoder).Decode": unmarshalPointer("Decode", 0), + } + + checkUnbufferedSignalChanRules = map[string]CallCheck{ + "os/signal.Notify": func(call *Call) { + arg := call.Args[0] + if UnbufferedChannel(arg.Value) { + arg.Invalid("the channel used with signal.Notify should be buffered") + } + }, + } + + checkMathIntRules = map[string]CallCheck{ + "math.Ceil": pointlessIntMath, + "math.Floor": pointlessIntMath, + "math.IsNaN": pointlessIntMath, + "math.Trunc": pointlessIntMath, + "math.IsInf": pointlessIntMath, + } + + checkStringsReplaceZeroRules = map[string]CallCheck{ + "strings.Replace": RepeatZeroTimes("strings.Replace", 3), + "bytes.Replace": RepeatZeroTimes("bytes.Replace", 3), + } + + checkListenAddressRules = map[string]CallCheck{ + "net/http.ListenAndServe": checkValidHostPort(0), + "net/http.ListenAndServeTLS": checkValidHostPort(0), + } + + checkBytesEqualIPRules = map[string]CallCheck{ + "bytes.Equal": func(call *Call) { + if ConvertedFrom(call.Args[0].Value, "net.IP") && ConvertedFrom(call.Args[1].Value, "net.IP") { + call.Invalid("use net.IP.Equal to compare net.IPs, not bytes.Equal") + } + }, + } + + checkRegexpMatchLoopRules = map[string]CallCheck{ + "regexp.Match": loopedRegexp("regexp.Match"), + "regexp.MatchReader": loopedRegexp("regexp.MatchReader"), + "regexp.MatchString": loopedRegexp("regexp.MatchString"), + } +) + +type Checker struct { + CheckGenerated bool + funcDescs *functions.Descriptions + deprecatedObjs map[types.Object]string +} + +func NewChecker() *Checker { + return &Checker{} +} + +func (*Checker) Name() string { return "staticcheck" } +func (*Checker) Prefix() string { return "SA" } + +func (c *Checker) Funcs() map[string]lint.Func { + return map[string]lint.Func{ + "SA1000": c.callChecker(checkRegexpRules), + "SA1001": c.CheckTemplate, + "SA1002": c.callChecker(checkTimeParseRules), + "SA1003": c.callChecker(checkEncodingBinaryRules), + "SA1004": c.CheckTimeSleepConstant, + "SA1005": c.CheckExec, + "SA1006": c.CheckUnsafePrintf, + "SA1007": c.callChecker(checkURLsRules), + "SA1008": c.CheckCanonicalHeaderKey, + "SA1009": nil, + "SA1010": c.callChecker(checkRegexpFindAllRules), + "SA1011": c.callChecker(checkUTF8CutsetRules), + "SA1012": c.CheckNilContext, + "SA1013": c.CheckSeeker, + "SA1014": c.callChecker(checkUnmarshalPointerRules), + "SA1015": c.CheckLeakyTimeTick, + "SA1016": c.CheckUntrappableSignal, + "SA1017": c.callChecker(checkUnbufferedSignalChanRules), + "SA1018": c.callChecker(checkStringsReplaceZeroRules), + "SA1019": c.CheckDeprecated, + "SA1020": c.callChecker(checkListenAddressRules), + "SA1021": c.callChecker(checkBytesEqualIPRules), + "SA1022": nil, + "SA1023": c.CheckWriterBufferModified, + "SA1024": c.callChecker(checkUniqueCutsetRules), + + "SA2000": c.CheckWaitgroupAdd, + "SA2001": c.CheckEmptyCriticalSection, + "SA2002": c.CheckConcurrentTesting, + "SA2003": c.CheckDeferLock, + + "SA3000": c.CheckTestMainExit, + "SA3001": c.CheckBenchmarkN, + + "SA4000": c.CheckLhsRhsIdentical, + "SA4001": c.CheckIneffectiveCopy, + "SA4002": c.CheckDiffSizeComparison, + "SA4003": c.CheckUnsignedComparison, + "SA4004": c.CheckIneffectiveLoop, + "SA4005": nil, + "SA4006": c.CheckUnreadVariableValues, + // "SA4007": c.CheckPredeterminedBooleanExprs, + "SA4007": nil, + "SA4008": c.CheckLoopCondition, + "SA4009": c.CheckArgOverwritten, + "SA4010": c.CheckIneffectiveAppend, + "SA4011": c.CheckScopedBreak, + "SA4012": c.CheckNaNComparison, + "SA4013": c.CheckDoubleNegation, + "SA4014": c.CheckRepeatedIfElse, + "SA4015": c.callChecker(checkMathIntRules), + "SA4016": c.CheckSillyBitwiseOps, + "SA4017": c.CheckPureFunctions, + "SA4018": c.CheckSelfAssignment, + "SA4019": c.CheckDuplicateBuildConstraints, + + "SA5000": c.CheckNilMaps, + "SA5001": c.CheckEarlyDefer, + "SA5002": c.CheckInfiniteEmptyLoop, + "SA5003": c.CheckDeferInInfiniteLoop, + "SA5004": c.CheckLoopEmptyDefault, + "SA5005": c.CheckCyclicFinalizer, + // "SA5006": c.CheckSliceOutOfBounds, + "SA5007": c.CheckInfiniteRecursion, + + "SA6000": c.callChecker(checkRegexpMatchLoopRules), + "SA6001": c.CheckMapBytesKey, + "SA6002": c.callChecker(checkSyncPoolValueRules), + "SA6003": c.CheckRangeStringRunes, + "SA6004": c.CheckSillyRegexp, + + "SA9000": nil, + "SA9001": c.CheckDubiousDeferInChannelRangeLoop, + "SA9002": c.CheckNonOctalFileMode, + "SA9003": c.CheckEmptyBranch, + "SA9004": c.CheckMissingEnumTypesInDeclaration, + } +} + +func (c *Checker) filterGenerated(files []*ast.File) []*ast.File { + if c.CheckGenerated { + return files + } + var out []*ast.File + for _, f := range files { + if !IsGenerated(f) { + out = append(out, f) + } + } + return out +} + +func (c *Checker) findDeprecated(prog *lint.Program) { + var docs []*ast.CommentGroup + var names []*ast.Ident + + doDocs := func(pkginfo *loader.PackageInfo, names []*ast.Ident, docs []*ast.CommentGroup) { + var alt string + for _, doc := range docs { + if doc == nil { + continue + } + parts := strings.Split(doc.Text(), "\n\n") + last := parts[len(parts)-1] + if !strings.HasPrefix(last, "Deprecated: ") { + continue + } + alt = last[len("Deprecated: "):] + alt = strings.Replace(alt, "\n", " ", -1) + break + } + if alt == "" { + return + } + + for _, name := range names { + obj := pkginfo.ObjectOf(name) + c.deprecatedObjs[obj] = alt + } + } + + for _, pkginfo := range prog.Prog.AllPackages { + for _, f := range pkginfo.Files { + fn := func(node ast.Node) bool { + if node == nil { + return true + } + var ret bool + switch node := node.(type) { + case *ast.GenDecl: + switch node.Tok { + case token.TYPE, token.CONST, token.VAR: + docs = append(docs, node.Doc) + return true + default: + return false + } + case *ast.FuncDecl: + docs = append(docs, node.Doc) + names = []*ast.Ident{node.Name} + ret = false + case *ast.TypeSpec: + docs = append(docs, node.Doc) + names = []*ast.Ident{node.Name} + ret = true + case *ast.ValueSpec: + docs = append(docs, node.Doc) + names = node.Names + ret = false + case *ast.File: + return true + case *ast.StructType: + for _, field := range node.Fields.List { + doDocs(pkginfo, field.Names, []*ast.CommentGroup{field.Doc}) + } + return false + case *ast.InterfaceType: + for _, field := range node.Methods.List { + doDocs(pkginfo, field.Names, []*ast.CommentGroup{field.Doc}) + } + return false + default: + return false + } + if len(names) == 0 || len(docs) == 0 { + return ret + } + doDocs(pkginfo, names, docs) + + docs = docs[:0] + names = nil + return ret + } + ast.Inspect(f, fn) + } + } +} + +func (c *Checker) Init(prog *lint.Program) { + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + c.funcDescs = functions.NewDescriptions(prog.SSA) + for _, fn := range prog.AllFunctions { + if fn.Blocks != nil { + applyStdlibKnowledge(fn) + ssa.OptimizeBlocks(fn) + } + } + wg.Done() + }() + + go func() { + c.deprecatedObjs = map[types.Object]string{} + c.findDeprecated(prog) + wg.Done() + }() + + wg.Wait() +} + +func (c *Checker) isInLoop(b *ssa.BasicBlock) bool { + sets := c.funcDescs.Get(b.Parent()).Loops + for _, set := range sets { + if set[b] { + return true + } + } + return false +} + +func applyStdlibKnowledge(fn *ssa.Function) { + if len(fn.Blocks) == 0 { + return + } + + // comma-ok receiving from a time.Tick channel will never return + // ok == false, so any branching on the value of ok can be + // replaced with an unconditional jump. This will primarily match + // `for range time.Tick(x)` loops, but it can also match + // user-written code. + for _, block := range fn.Blocks { + if len(block.Instrs) < 3 { + continue + } + if len(block.Succs) != 2 { + continue + } + var instrs []*ssa.Instruction + for i, ins := range block.Instrs { + if _, ok := ins.(*ssa.DebugRef); ok { + continue + } + instrs = append(instrs, &block.Instrs[i]) + } + + for i, ins := range instrs { + unop, ok := (*ins).(*ssa.UnOp) + if !ok || unop.Op != token.ARROW { + continue + } + call, ok := unop.X.(*ssa.Call) + if !ok { + continue + } + if !IsCallTo(call.Common(), "time.Tick") { + continue + } + ex, ok := (*instrs[i+1]).(*ssa.Extract) + if !ok || ex.Tuple != unop || ex.Index != 1 { + continue + } + + ifstmt, ok := (*instrs[i+2]).(*ssa.If) + if !ok || ifstmt.Cond != ex { + continue + } + + *instrs[i+2] = ssa.NewJump(block) + succ := block.Succs[1] + block.Succs = block.Succs[0:1] + succ.RemovePred(block) + } + } +} + +func hasType(j *lint.Job, expr ast.Expr, name string) bool { + T := TypeOf(j, expr) + return IsType(T, name) +} + +func (c *Checker) CheckUntrappableSignal(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + if !IsCallToAnyAST(j, call, + "os/signal.Ignore", "os/signal.Notify", "os/signal.Reset") { + return true + } + for _, arg := range call.Args { + if conv, ok := arg.(*ast.CallExpr); ok && isName(j, conv.Fun, "os.Signal") { + arg = conv.Args[0] + } + + if isName(j, arg, "os.Kill") || isName(j, arg, "syscall.SIGKILL") { + j.Errorf(arg, "%s cannot be trapped (did you mean syscall.SIGTERM?)", Render(j, arg)) + } + if isName(j, arg, "syscall.SIGSTOP") { + j.Errorf(arg, "%s signal cannot be trapped", Render(j, arg)) + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckTemplate(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + var kind string + if IsCallToAST(j, call, "(*text/template.Template).Parse") { + kind = "text" + } else if IsCallToAST(j, call, "(*html/template.Template).Parse") { + kind = "html" + } else { + return true + } + sel := call.Fun.(*ast.SelectorExpr) + if !IsCallToAST(j, sel.X, "text/template.New") && + !IsCallToAST(j, sel.X, "html/template.New") { + // TODO(dh): this is a cheap workaround for templates with + // different delims. A better solution with less false + // negatives would use data flow analysis to see where the + // template comes from and where it has been + return true + } + s, ok := ExprToString(j, call.Args[0]) + if !ok { + return true + } + var err error + switch kind { + case "text": + _, err = texttemplate.New("").Parse(s) + case "html": + _, err = htmltemplate.New("").Parse(s) + } + if err != nil { + // TODO(dominikh): whitelist other parse errors, if any + if strings.Contains(err.Error(), "unexpected") { + j.Errorf(call.Args[0], "%s", err) + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckTimeSleepConstant(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + if !IsCallToAST(j, call, "time.Sleep") { + return true + } + lit, ok := call.Args[0].(*ast.BasicLit) + if !ok { + return true + } + n, err := strconv.Atoi(lit.Value) + if err != nil { + return true + } + if n == 0 || n > 120 { + // time.Sleep(0) is a seldom used pattern in concurrency + // tests. >120 might be intentional. 120 was chosen + // because the user could've meant 2 minutes. + return true + } + recommendation := "time.Sleep(time.Nanosecond)" + if n != 1 { + recommendation = fmt.Sprintf("time.Sleep(%d * time.Nanosecond)", n) + } + j.Errorf(call.Args[0], "sleeping for %d nanoseconds is probably a bug. Be explicit if it isn't: %s", n, recommendation) + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckWaitgroupAdd(j *lint.Job) { + fn := func(node ast.Node) bool { + g, ok := node.(*ast.GoStmt) + if !ok { + return true + } + fun, ok := g.Call.Fun.(*ast.FuncLit) + if !ok { + return true + } + if len(fun.Body.List) == 0 { + return true + } + stmt, ok := fun.Body.List[0].(*ast.ExprStmt) + if !ok { + return true + } + call, ok := stmt.X.(*ast.CallExpr) + if !ok { + return true + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + fn, ok := ObjectOf(j, sel.Sel).(*types.Func) + if !ok { + return true + } + if fn.FullName() == "(*sync.WaitGroup).Add" { + j.Errorf(sel, "should call %s before starting the goroutine to avoid a race", + Render(j, stmt)) + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckInfiniteEmptyLoop(j *lint.Job) { + fn := func(node ast.Node) bool { + loop, ok := node.(*ast.ForStmt) + if !ok || len(loop.Body.List) != 0 || loop.Post != nil { + return true + } + + if loop.Init != nil { + // TODO(dh): this isn't strictly necessary, it just makes + // the check easier. + return true + } + // An empty loop is bad news in two cases: 1) The loop has no + // condition. In that case, it's just a loop that spins + // forever and as fast as it can, keeping a core busy. 2) The + // loop condition only consists of variable or field reads and + // operators on those. The only way those could change their + // value is with unsynchronised access, which constitutes a + // data race. + // + // If the condition contains any function calls, its behaviour + // is dynamic and the loop might terminate. Similarly for + // channel receives. + + if loop.Cond != nil && hasSideEffects(loop.Cond) { + return true + } + + j.Errorf(loop, "this loop will spin, using 100%% CPU") + if loop.Cond != nil { + j.Errorf(loop, "loop condition never changes or has a race condition") + } + + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckDeferInInfiniteLoop(j *lint.Job) { + fn := func(node ast.Node) bool { + mightExit := false + var defers []ast.Stmt + loop, ok := node.(*ast.ForStmt) + if !ok || loop.Cond != nil { + return true + } + fn2 := func(node ast.Node) bool { + switch stmt := node.(type) { + case *ast.ReturnStmt: + mightExit = true + case *ast.BranchStmt: + // TODO(dominikh): if this sees a break in a switch or + // select, it doesn't check if it breaks the loop or + // just the select/switch. This causes some false + // negatives. + if stmt.Tok == token.BREAK { + mightExit = true + } + case *ast.DeferStmt: + defers = append(defers, stmt) + case *ast.FuncLit: + // Don't look into function bodies + return false + } + return true + } + ast.Inspect(loop.Body, fn2) + if mightExit { + return true + } + for _, stmt := range defers { + j.Errorf(stmt, "defers in this infinite loop will never run") + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckDubiousDeferInChannelRangeLoop(j *lint.Job) { + fn := func(node ast.Node) bool { + loop, ok := node.(*ast.RangeStmt) + if !ok { + return true + } + typ := TypeOf(j, loop.X) + _, ok = typ.Underlying().(*types.Chan) + if !ok { + return true + } + fn2 := func(node ast.Node) bool { + switch stmt := node.(type) { + case *ast.DeferStmt: + j.Errorf(stmt, "defers in this range loop won't run unless the channel gets closed") + case *ast.FuncLit: + // Don't look into function bodies + return false + } + return true + } + ast.Inspect(loop.Body, fn2) + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckTestMainExit(j *lint.Job) { + fn := func(node ast.Node) bool { + if !isTestMain(j, node) { + return true + } + + arg := ObjectOf(j, node.(*ast.FuncDecl).Type.Params.List[0].Names[0]) + callsRun := false + fn2 := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + if arg != ObjectOf(j, ident) { + return true + } + if sel.Sel.Name == "Run" { + callsRun = true + return false + } + return true + } + ast.Inspect(node.(*ast.FuncDecl).Body, fn2) + + callsExit := false + fn3 := func(node ast.Node) bool { + if IsCallToAST(j, node, "os.Exit") { + callsExit = true + return false + } + return true + } + ast.Inspect(node.(*ast.FuncDecl).Body, fn3) + if !callsExit && callsRun { + j.Errorf(node, "TestMain should call os.Exit to set exit code") + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func isTestMain(j *lint.Job, node ast.Node) bool { + decl, ok := node.(*ast.FuncDecl) + if !ok { + return false + } + if decl.Name.Name != "TestMain" { + return false + } + if len(decl.Type.Params.List) != 1 { + return false + } + arg := decl.Type.Params.List[0] + if len(arg.Names) != 1 { + return false + } + return IsOfType(j, arg.Type, "*testing.M") +} + +func (c *Checker) CheckExec(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + if !IsCallToAST(j, call, "os/exec.Command") { + return true + } + val, ok := ExprToString(j, call.Args[0]) + if !ok { + return true + } + if !strings.Contains(val, " ") || strings.Contains(val, `\`) || strings.Contains(val, "/") { + return true + } + j.Errorf(call.Args[0], "first argument to exec.Command looks like a shell command, but a program name or path are expected") + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckLoopEmptyDefault(j *lint.Job) { + fn := func(node ast.Node) bool { + loop, ok := node.(*ast.ForStmt) + if !ok || len(loop.Body.List) != 1 || loop.Cond != nil || loop.Init != nil { + return true + } + sel, ok := loop.Body.List[0].(*ast.SelectStmt) + if !ok { + return true + } + for _, c := range sel.Body.List { + if comm, ok := c.(*ast.CommClause); ok && comm.Comm == nil && len(comm.Body) == 0 { + j.Errorf(comm, "should not have an empty default case in a for+select loop. The loop will spin.") + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckLhsRhsIdentical(j *lint.Job) { + fn := func(node ast.Node) bool { + op, ok := node.(*ast.BinaryExpr) + if !ok { + return true + } + switch op.Op { + case token.EQL, token.NEQ: + if basic, ok := TypeOf(j, op.X).(*types.Basic); ok { + if kind := basic.Kind(); kind == types.Float32 || kind == types.Float64 { + // f == f and f != f might be used to check for NaN + return true + } + } + case token.SUB, token.QUO, token.AND, token.REM, token.OR, token.XOR, token.AND_NOT, + token.LAND, token.LOR, token.LSS, token.GTR, token.LEQ, token.GEQ: + default: + // For some ops, such as + and *, it can make sense to + // have identical operands + return true + } + + if Render(j, op.X) != Render(j, op.Y) { + return true + } + j.Errorf(op, "identical expressions on the left and right side of the '%s' operator", op.Op) + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckScopedBreak(j *lint.Job) { + fn := func(node ast.Node) bool { + var body *ast.BlockStmt + switch node := node.(type) { + case *ast.ForStmt: + body = node.Body + case *ast.RangeStmt: + body = node.Body + default: + return true + } + for _, stmt := range body.List { + var blocks [][]ast.Stmt + switch stmt := stmt.(type) { + case *ast.SwitchStmt: + for _, c := range stmt.Body.List { + blocks = append(blocks, c.(*ast.CaseClause).Body) + } + case *ast.SelectStmt: + for _, c := range stmt.Body.List { + blocks = append(blocks, c.(*ast.CommClause).Body) + } + default: + continue + } + + for _, body := range blocks { + if len(body) == 0 { + continue + } + lasts := []ast.Stmt{body[len(body)-1]} + // TODO(dh): unfold all levels of nested block + // statements, not just a single level if statement + if ifs, ok := lasts[0].(*ast.IfStmt); ok { + if len(ifs.Body.List) == 0 { + continue + } + lasts[0] = ifs.Body.List[len(ifs.Body.List)-1] + + if block, ok := ifs.Else.(*ast.BlockStmt); ok { + if len(block.List) != 0 { + lasts = append(lasts, block.List[len(block.List)-1]) + } + } + } + for _, last := range lasts { + branch, ok := last.(*ast.BranchStmt) + if !ok || branch.Tok != token.BREAK || branch.Label != nil { + continue + } + j.Errorf(branch, "ineffective break statement. Did you mean to break out of the outer loop?") + } + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckUnsafePrintf(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + if !IsCallToAnyAST(j, call, "fmt.Printf", "fmt.Sprintf", "log.Printf") { + return true + } + if len(call.Args) != 1 { + return true + } + switch call.Args[0].(type) { + case *ast.CallExpr, *ast.Ident: + default: + return true + } + j.Errorf(call.Args[0], "printf-style function with dynamic first argument and no further arguments should use print-style function instead") + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckEarlyDefer(j *lint.Job) { + fn := func(node ast.Node) bool { + block, ok := node.(*ast.BlockStmt) + if !ok { + return true + } + if len(block.List) < 2 { + return true + } + for i, stmt := range block.List { + if i == len(block.List)-1 { + break + } + assign, ok := stmt.(*ast.AssignStmt) + if !ok { + continue + } + if len(assign.Rhs) != 1 { + continue + } + if len(assign.Lhs) < 2 { + continue + } + if lhs, ok := assign.Lhs[len(assign.Lhs)-1].(*ast.Ident); ok && lhs.Name == "_" { + continue + } + call, ok := assign.Rhs[0].(*ast.CallExpr) + if !ok { + continue + } + sig, ok := TypeOf(j, call.Fun).(*types.Signature) + if !ok { + continue + } + if sig.Results().Len() < 2 { + continue + } + last := sig.Results().At(sig.Results().Len() - 1) + // FIXME(dh): check that it's error from universe, not + // another type of the same name + if last.Type().String() != "error" { + continue + } + lhs, ok := assign.Lhs[0].(*ast.Ident) + if !ok { + continue + } + def, ok := block.List[i+1].(*ast.DeferStmt) + if !ok { + continue + } + sel, ok := def.Call.Fun.(*ast.SelectorExpr) + if !ok { + continue + } + ident, ok := selectorX(sel).(*ast.Ident) + if !ok { + continue + } + if ident.Obj != lhs.Obj { + continue + } + if sel.Sel.Name != "Close" { + continue + } + j.Errorf(def, "should check returned error before deferring %s", Render(j, def.Call)) + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func selectorX(sel *ast.SelectorExpr) ast.Node { + switch x := sel.X.(type) { + case *ast.SelectorExpr: + return selectorX(x) + default: + return x + } +} + +func (c *Checker) CheckEmptyCriticalSection(j *lint.Job) { + // Initially it might seem like this check would be easier to + // implement in SSA. After all, we're only checking for two + // consecutive method calls. In reality, however, there may be any + // number of other instructions between the lock and unlock, while + // still constituting an empty critical section. For example, + // given `m.x().Lock(); m.x().Unlock()`, there will be a call to + // x(). In the AST-based approach, this has a tiny potential for a + // false positive (the second call to x might be doing work that + // is protected by the mutex). In an SSA-based approach, however, + // it would miss a lot of real bugs. + + mutexParams := func(s ast.Stmt) (x ast.Expr, funcName string, ok bool) { + expr, ok := s.(*ast.ExprStmt) + if !ok { + return nil, "", false + } + call, ok := expr.X.(*ast.CallExpr) + if !ok { + return nil, "", false + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return nil, "", false + } + + fn, ok := ObjectOf(j, sel.Sel).(*types.Func) + if !ok { + return nil, "", false + } + sig := fn.Type().(*types.Signature) + if sig.Params().Len() != 0 || sig.Results().Len() != 0 { + return nil, "", false + } + + return sel.X, fn.Name(), true + } + + fn := func(node ast.Node) bool { + block, ok := node.(*ast.BlockStmt) + if !ok { + return true + } + if len(block.List) < 2 { + return true + } + for i := range block.List[:len(block.List)-1] { + sel1, method1, ok1 := mutexParams(block.List[i]) + sel2, method2, ok2 := mutexParams(block.List[i+1]) + + if !ok1 || !ok2 || Render(j, sel1) != Render(j, sel2) { + continue + } + if (method1 == "Lock" && method2 == "Unlock") || + (method1 == "RLock" && method2 == "RUnlock") { + j.Errorf(block.List[i+1], "empty critical section") + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +// cgo produces code like fn(&*_Cvar_kSomeCallbacks) which we don't +// want to flag. +var cgoIdent = regexp.MustCompile(`^_C(func|var)_.+$`) + +func (c *Checker) CheckIneffectiveCopy(j *lint.Job) { + fn := func(node ast.Node) bool { + if unary, ok := node.(*ast.UnaryExpr); ok { + if star, ok := unary.X.(*ast.StarExpr); ok && unary.Op == token.AND { + ident, ok := star.X.(*ast.Ident) + if !ok || !cgoIdent.MatchString(ident.Name) { + j.Errorf(unary, "&*x will be simplified to x. It will not copy x.") + } + } + } + + if star, ok := node.(*ast.StarExpr); ok { + if unary, ok := star.X.(*ast.UnaryExpr); ok && unary.Op == token.AND { + j.Errorf(star, "*&x will be simplified to x. It will not copy x.") + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckDiffSizeComparison(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + for _, b := range ssafn.Blocks { + for _, ins := range b.Instrs { + binop, ok := ins.(*ssa.BinOp) + if !ok { + continue + } + if binop.Op != token.EQL && binop.Op != token.NEQ { + continue + } + _, ok1 := binop.X.(*ssa.Slice) + _, ok2 := binop.Y.(*ssa.Slice) + if !ok1 && !ok2 { + continue + } + r := c.funcDescs.Get(ssafn).Ranges + r1, ok1 := r.Get(binop.X).(vrp.StringInterval) + r2, ok2 := r.Get(binop.Y).(vrp.StringInterval) + if !ok1 || !ok2 { + continue + } + if r1.Length.Intersection(r2.Length).Empty() { + j.Errorf(binop, "comparing strings of different sizes for equality will always return false") + } + } + } + } +} + +func (c *Checker) CheckCanonicalHeaderKey(j *lint.Job) { + fn := func(node ast.Node) bool { + assign, ok := node.(*ast.AssignStmt) + if ok { + // TODO(dh): This risks missing some Header reads, for + // example in `h1["foo"] = h2["foo"]` – these edge + // cases are probably rare enough to ignore for now. + for _, expr := range assign.Lhs { + op, ok := expr.(*ast.IndexExpr) + if !ok { + continue + } + if hasType(j, op.X, "net/http.Header") { + return false + } + } + return true + } + op, ok := node.(*ast.IndexExpr) + if !ok { + return true + } + if !hasType(j, op.X, "net/http.Header") { + return true + } + s, ok := ExprToString(j, op.Index) + if !ok { + return true + } + if s == http.CanonicalHeaderKey(s) { + return true + } + j.Errorf(op, "keys in http.Header are canonicalized, %q is not canonical; fix the constant or use http.CanonicalHeaderKey", s) + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckBenchmarkN(j *lint.Job) { + fn := func(node ast.Node) bool { + assign, ok := node.(*ast.AssignStmt) + if !ok { + return true + } + if len(assign.Lhs) != 1 || len(assign.Rhs) != 1 { + return true + } + sel, ok := assign.Lhs[0].(*ast.SelectorExpr) + if !ok { + return true + } + if sel.Sel.Name != "N" { + return true + } + if !hasType(j, sel.X, "*testing.B") { + return true + } + j.Errorf(assign, "should not assign to %s", Render(j, sel)) + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckUnreadVariableValues(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + if IsExample(ssafn) { + continue + } + node := ssafn.Syntax() + if node == nil { + continue + } + + ast.Inspect(node, func(node ast.Node) bool { + assign, ok := node.(*ast.AssignStmt) + if !ok { + return true + } + if len(assign.Lhs) > 1 && len(assign.Rhs) == 1 { + // Either a function call with multiple return values, + // or a comma-ok assignment + + val, _ := ssafn.ValueForExpr(assign.Rhs[0]) + if val == nil { + return true + } + refs := val.Referrers() + if refs == nil { + return true + } + for _, ref := range *refs { + ex, ok := ref.(*ssa.Extract) + if !ok { + continue + } + exrefs := ex.Referrers() + if exrefs == nil { + continue + } + if len(FilterDebug(*exrefs)) == 0 { + lhs := assign.Lhs[ex.Index] + if ident, ok := lhs.(*ast.Ident); !ok || ok && ident.Name == "_" { + continue + } + j.Errorf(lhs, "this value of %s is never used", lhs) + } + } + return true + } + for i, lhs := range assign.Lhs { + rhs := assign.Rhs[i] + if ident, ok := lhs.(*ast.Ident); !ok || ok && ident.Name == "_" { + continue + } + val, _ := ssafn.ValueForExpr(rhs) + if val == nil { + continue + } + + refs := val.Referrers() + if refs == nil { + // TODO investigate why refs can be nil + return true + } + if len(FilterDebug(*refs)) == 0 { + j.Errorf(lhs, "this value of %s is never used", lhs) + } + } + return true + }) + } +} + +func (c *Checker) CheckPredeterminedBooleanExprs(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + ssabinop, ok := ins.(*ssa.BinOp) + if !ok { + continue + } + switch ssabinop.Op { + case token.GTR, token.LSS, token.EQL, token.NEQ, token.LEQ, token.GEQ: + default: + continue + } + + xs, ok1 := consts(ssabinop.X, nil, nil) + ys, ok2 := consts(ssabinop.Y, nil, nil) + if !ok1 || !ok2 || len(xs) == 0 || len(ys) == 0 { + continue + } + + trues := 0 + for _, x := range xs { + for _, y := range ys { + if x.Value == nil { + if y.Value == nil { + trues++ + } + continue + } + if constant.Compare(x.Value, ssabinop.Op, y.Value) { + trues++ + } + } + } + b := trues != 0 + if trues == 0 || trues == len(xs)*len(ys) { + j.Errorf(ssabinop, "binary expression is always %t for all possible values (%s %s %s)", + b, xs, ssabinop.Op, ys) + } + } + } + } +} + +func (c *Checker) CheckNilMaps(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + mu, ok := ins.(*ssa.MapUpdate) + if !ok { + continue + } + c, ok := mu.Map.(*ssa.Const) + if !ok { + continue + } + if c.Value != nil { + continue + } + j.Errorf(mu, "assignment to nil map") + } + } + } +} + +func (c *Checker) CheckUnsignedComparison(j *lint.Job) { + fn := func(node ast.Node) bool { + expr, ok := node.(*ast.BinaryExpr) + if !ok { + return true + } + tx := TypeOf(j, expr.X) + basic, ok := tx.Underlying().(*types.Basic) + if !ok { + return true + } + if (basic.Info() & types.IsUnsigned) == 0 { + return true + } + lit, ok := expr.Y.(*ast.BasicLit) + if !ok || lit.Value != "0" { + return true + } + switch expr.Op { + case token.GEQ: + j.Errorf(expr, "unsigned values are always >= 0") + case token.LSS: + j.Errorf(expr, "unsigned values are never < 0") + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func consts(val ssa.Value, out []*ssa.Const, visitedPhis map[string]bool) ([]*ssa.Const, bool) { + if visitedPhis == nil { + visitedPhis = map[string]bool{} + } + var ok bool + switch val := val.(type) { + case *ssa.Phi: + if visitedPhis[val.Name()] { + break + } + visitedPhis[val.Name()] = true + vals := val.Operands(nil) + for _, phival := range vals { + out, ok = consts(*phival, out, visitedPhis) + if !ok { + return nil, false + } + } + case *ssa.Const: + out = append(out, val) + case *ssa.Convert: + out, ok = consts(val.X, out, visitedPhis) + if !ok { + return nil, false + } + default: + return nil, false + } + if len(out) < 2 { + return out, true + } + uniq := []*ssa.Const{out[0]} + for _, val := range out[1:] { + if val.Value == uniq[len(uniq)-1].Value { + continue + } + uniq = append(uniq, val) + } + return uniq, true +} + +func (c *Checker) CheckLoopCondition(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + fn := func(node ast.Node) bool { + loop, ok := node.(*ast.ForStmt) + if !ok { + return true + } + if loop.Init == nil || loop.Cond == nil || loop.Post == nil { + return true + } + init, ok := loop.Init.(*ast.AssignStmt) + if !ok || len(init.Lhs) != 1 || len(init.Rhs) != 1 { + return true + } + cond, ok := loop.Cond.(*ast.BinaryExpr) + if !ok { + return true + } + x, ok := cond.X.(*ast.Ident) + if !ok { + return true + } + lhs, ok := init.Lhs[0].(*ast.Ident) + if !ok { + return true + } + if x.Obj != lhs.Obj { + return true + } + if _, ok := loop.Post.(*ast.IncDecStmt); !ok { + return true + } + + v, isAddr := ssafn.ValueForExpr(cond.X) + if v == nil || isAddr { + return true + } + switch v := v.(type) { + case *ssa.Phi: + ops := v.Operands(nil) + if len(ops) != 2 { + return true + } + _, ok := (*ops[0]).(*ssa.Const) + if !ok { + return true + } + sigma, ok := (*ops[1]).(*ssa.Sigma) + if !ok { + return true + } + if sigma.X != v { + return true + } + case *ssa.UnOp: + return true + } + j.Errorf(cond, "variable in loop condition never changes") + + return true + } + Inspect(ssafn.Syntax(), fn) + } +} + +func (c *Checker) CheckArgOverwritten(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + fn := func(node ast.Node) bool { + var typ *ast.FuncType + var body *ast.BlockStmt + switch fn := node.(type) { + case *ast.FuncDecl: + typ = fn.Type + body = fn.Body + case *ast.FuncLit: + typ = fn.Type + body = fn.Body + } + if body == nil { + return true + } + if len(typ.Params.List) == 0 { + return true + } + for _, field := range typ.Params.List { + for _, arg := range field.Names { + obj := ObjectOf(j, arg) + var ssaobj *ssa.Parameter + for _, param := range ssafn.Params { + if param.Object() == obj { + ssaobj = param + break + } + } + if ssaobj == nil { + continue + } + refs := ssaobj.Referrers() + if refs == nil { + continue + } + if len(FilterDebug(*refs)) != 0 { + continue + } + + assigned := false + ast.Inspect(body, func(node ast.Node) bool { + assign, ok := node.(*ast.AssignStmt) + if !ok { + return true + } + for _, lhs := range assign.Lhs { + ident, ok := lhs.(*ast.Ident) + if !ok { + continue + } + if ObjectOf(j, ident) == obj { + assigned = true + return false + } + } + return true + }) + if assigned { + j.Errorf(arg, "argument %s is overwritten before first use", arg) + } + } + } + return true + } + Inspect(ssafn.Syntax(), fn) + } +} + +func (c *Checker) CheckIneffectiveLoop(j *lint.Job) { + // This check detects some, but not all unconditional loop exits. + // We give up in the following cases: + // + // - a goto anywhere in the loop. The goto might skip over our + // return, and we don't check that it doesn't. + // + // - any nested, unlabelled continue, even if it is in another + // loop or closure. + fn := func(node ast.Node) bool { + var body *ast.BlockStmt + switch fn := node.(type) { + case *ast.FuncDecl: + body = fn.Body + case *ast.FuncLit: + body = fn.Body + default: + return true + } + if body == nil { + return true + } + labels := map[*ast.Object]ast.Stmt{} + ast.Inspect(body, func(node ast.Node) bool { + label, ok := node.(*ast.LabeledStmt) + if !ok { + return true + } + labels[label.Label.Obj] = label.Stmt + return true + }) + + ast.Inspect(body, func(node ast.Node) bool { + var loop ast.Node + var body *ast.BlockStmt + switch node := node.(type) { + case *ast.ForStmt: + body = node.Body + loop = node + case *ast.RangeStmt: + typ := TypeOf(j, node.X) + if _, ok := typ.Underlying().(*types.Map); ok { + // looping once over a map is a valid pattern for + // getting an arbitrary element. + return true + } + body = node.Body + loop = node + default: + return true + } + if len(body.List) < 2 { + // avoid flagging the somewhat common pattern of using + // a range loop to get the first element in a slice, + // or the first rune in a string. + return true + } + var unconditionalExit ast.Node + hasBranching := false + for _, stmt := range body.List { + switch stmt := stmt.(type) { + case *ast.BranchStmt: + switch stmt.Tok { + case token.BREAK: + if stmt.Label == nil || labels[stmt.Label.Obj] == loop { + unconditionalExit = stmt + } + case token.CONTINUE: + if stmt.Label == nil || labels[stmt.Label.Obj] == loop { + unconditionalExit = nil + return false + } + } + case *ast.ReturnStmt: + unconditionalExit = stmt + case *ast.IfStmt, *ast.ForStmt, *ast.RangeStmt, *ast.SwitchStmt, *ast.SelectStmt: + hasBranching = true + } + } + if unconditionalExit == nil || !hasBranching { + return false + } + ast.Inspect(body, func(node ast.Node) bool { + if branch, ok := node.(*ast.BranchStmt); ok { + + switch branch.Tok { + case token.GOTO: + unconditionalExit = nil + return false + case token.CONTINUE: + if branch.Label != nil && labels[branch.Label.Obj] != loop { + return true + } + unconditionalExit = nil + return false + } + } + return true + }) + if unconditionalExit != nil { + j.Errorf(unconditionalExit, "the surrounding loop is unconditionally terminated") + } + return true + }) + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckNilContext(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + if len(call.Args) == 0 { + return true + } + if typ, ok := TypeOf(j, call.Args[0]).(*types.Basic); !ok || typ.Kind() != types.UntypedNil { + return true + } + sig, ok := TypeOf(j, call.Fun).(*types.Signature) + if !ok { + return true + } + if sig.Params().Len() == 0 { + return true + } + if !IsType(sig.Params().At(0).Type(), "context.Context") { + return true + } + j.Errorf(call.Args[0], + "do not pass a nil Context, even if a function permits it; pass context.TODO if you are unsure about which Context to use") + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckSeeker(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + if sel.Sel.Name != "Seek" { + return true + } + if len(call.Args) != 2 { + return true + } + arg0, ok := call.Args[0].(*ast.SelectorExpr) + if !ok { + return true + } + switch arg0.Sel.Name { + case "SeekStart", "SeekCurrent", "SeekEnd": + default: + return true + } + pkg, ok := arg0.X.(*ast.Ident) + if !ok { + return true + } + if pkg.Name != "io" { + return true + } + j.Errorf(call, "the first argument of io.Seeker is the offset, but an io.Seek* constant is being used instead") + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckIneffectiveAppend(j *lint.Job) { + isAppend := func(ins ssa.Value) bool { + call, ok := ins.(*ssa.Call) + if !ok { + return false + } + if call.Call.IsInvoke() { + return false + } + if builtin, ok := call.Call.Value.(*ssa.Builtin); !ok || builtin.Name() != "append" { + return false + } + return true + } + + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + val, ok := ins.(ssa.Value) + if !ok || !isAppend(val) { + continue + } + + isUsed := false + visited := map[ssa.Instruction]bool{} + var walkRefs func(refs []ssa.Instruction) + walkRefs = func(refs []ssa.Instruction) { + loop: + for _, ref := range refs { + if visited[ref] { + continue + } + visited[ref] = true + if _, ok := ref.(*ssa.DebugRef); ok { + continue + } + switch ref := ref.(type) { + case *ssa.Phi: + walkRefs(*ref.Referrers()) + case *ssa.Sigma: + walkRefs(*ref.Referrers()) + case ssa.Value: + if !isAppend(ref) { + isUsed = true + } else { + walkRefs(*ref.Referrers()) + } + case ssa.Instruction: + isUsed = true + break loop + } + } + } + refs := val.Referrers() + if refs == nil { + continue + } + walkRefs(*refs) + if !isUsed { + j.Errorf(ins, "this result of append is never used, except maybe in other appends") + } + } + } + } +} + +func (c *Checker) CheckConcurrentTesting(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + gostmt, ok := ins.(*ssa.Go) + if !ok { + continue + } + var fn *ssa.Function + switch val := gostmt.Call.Value.(type) { + case *ssa.Function: + fn = val + case *ssa.MakeClosure: + fn = val.Fn.(*ssa.Function) + default: + continue + } + if fn.Blocks == nil { + continue + } + for _, block := range fn.Blocks { + for _, ins := range block.Instrs { + call, ok := ins.(*ssa.Call) + if !ok { + continue + } + if call.Call.IsInvoke() { + continue + } + callee := call.Call.StaticCallee() + if callee == nil { + continue + } + recv := callee.Signature.Recv() + if recv == nil { + continue + } + if !IsType(recv.Type(), "*testing.common") { + continue + } + fn, ok := call.Call.StaticCallee().Object().(*types.Func) + if !ok { + continue + } + name := fn.Name() + switch name { + case "FailNow", "Fatal", "Fatalf", "SkipNow", "Skip", "Skipf": + default: + continue + } + j.Errorf(gostmt, "the goroutine calls T.%s, which must be called in the same goroutine as the test", name) + } + } + } + } + } +} + +func (c *Checker) CheckCyclicFinalizer(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + node := c.funcDescs.CallGraph.CreateNode(ssafn) + for _, edge := range node.Out { + if edge.Callee.Func.RelString(nil) != "runtime.SetFinalizer" { + continue + } + arg0 := edge.Site.Common().Args[0] + if iface, ok := arg0.(*ssa.MakeInterface); ok { + arg0 = iface.X + } + unop, ok := arg0.(*ssa.UnOp) + if !ok { + continue + } + v, ok := unop.X.(*ssa.Alloc) + if !ok { + continue + } + arg1 := edge.Site.Common().Args[1] + if iface, ok := arg1.(*ssa.MakeInterface); ok { + arg1 = iface.X + } + mc, ok := arg1.(*ssa.MakeClosure) + if !ok { + continue + } + for _, b := range mc.Bindings { + if b == v { + pos := j.Program.DisplayPosition(mc.Fn.Pos()) + j.Errorf(edge.Site, "the finalizer closes over the object, preventing the finalizer from ever running (at %s)", pos) + } + } + } + } +} + +func (c *Checker) CheckSliceOutOfBounds(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + ia, ok := ins.(*ssa.IndexAddr) + if !ok { + continue + } + if _, ok := ia.X.Type().Underlying().(*types.Slice); !ok { + continue + } + sr, ok1 := c.funcDescs.Get(ssafn).Ranges[ia.X].(vrp.SliceInterval) + idxr, ok2 := c.funcDescs.Get(ssafn).Ranges[ia.Index].(vrp.IntInterval) + if !ok1 || !ok2 || !sr.IsKnown() || !idxr.IsKnown() || sr.Length.Empty() || idxr.Empty() { + continue + } + if idxr.Lower.Cmp(sr.Length.Upper) >= 0 { + j.Errorf(ia, "index out of bounds") + } + } + } + } +} + +func (c *Checker) CheckDeferLock(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + instrs := FilterDebug(block.Instrs) + if len(instrs) < 2 { + continue + } + for i, ins := range instrs[:len(instrs)-1] { + call, ok := ins.(*ssa.Call) + if !ok { + continue + } + if !IsCallTo(call.Common(), "(*sync.Mutex).Lock") && !IsCallTo(call.Common(), "(*sync.RWMutex).RLock") { + continue + } + nins, ok := instrs[i+1].(*ssa.Defer) + if !ok { + continue + } + if !IsCallTo(&nins.Call, "(*sync.Mutex).Lock") && !IsCallTo(&nins.Call, "(*sync.RWMutex).RLock") { + continue + } + if call.Common().Args[0] != nins.Call.Args[0] { + continue + } + name := shortCallName(call.Common()) + alt := "" + switch name { + case "Lock": + alt = "Unlock" + case "RLock": + alt = "RUnlock" + } + j.Errorf(nins, "deferring %s right after having locked already; did you mean to defer %s?", name, alt) + } + } + } +} + +func (c *Checker) CheckNaNComparison(j *lint.Job) { + isNaN := func(v ssa.Value) bool { + call, ok := v.(*ssa.Call) + if !ok { + return false + } + return IsCallTo(call.Common(), "math.NaN") + } + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + ins, ok := ins.(*ssa.BinOp) + if !ok { + continue + } + if isNaN(ins.X) || isNaN(ins.Y) { + j.Errorf(ins, "no value is equal to NaN, not even NaN itself") + } + } + } + } +} + +func (c *Checker) CheckInfiniteRecursion(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + node := c.funcDescs.CallGraph.CreateNode(ssafn) + for _, edge := range node.Out { + if edge.Callee != node { + continue + } + if _, ok := edge.Site.(*ssa.Go); ok { + // Recursively spawning goroutines doesn't consume + // stack space infinitely, so don't flag it. + continue + } + + block := edge.Site.Block() + canReturn := false + for _, b := range ssafn.Blocks { + if block.Dominates(b) { + continue + } + if len(b.Instrs) == 0 { + continue + } + if _, ok := b.Instrs[len(b.Instrs)-1].(*ssa.Return); ok { + canReturn = true + break + } + } + if canReturn { + continue + } + j.Errorf(edge.Site, "infinite recursive call") + } + } +} + +func objectName(obj types.Object) string { + if obj == nil { + return "" + } + var name string + if obj.Pkg() != nil && obj.Pkg().Scope().Lookup(obj.Name()) == obj { + var s string + s = obj.Pkg().Path() + if s != "" { + name += s + "." + } + } + name += obj.Name() + return name +} + +func isName(j *lint.Job, expr ast.Expr, name string) bool { + var obj types.Object + switch expr := expr.(type) { + case *ast.Ident: + obj = ObjectOf(j, expr) + case *ast.SelectorExpr: + obj = ObjectOf(j, expr.Sel) + } + return objectName(obj) == name +} + +func (c *Checker) CheckLeakyTimeTick(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + if IsInMain(j, ssafn) || IsInTest(j, ssafn) { + continue + } + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + call, ok := ins.(*ssa.Call) + if !ok || !IsCallTo(call.Common(), "time.Tick") { + continue + } + if c.funcDescs.Get(call.Parent()).Infinite { + continue + } + j.Errorf(call, "using time.Tick leaks the underlying ticker, consider using it only in endless functions, tests and the main package, and use time.NewTicker here") + } + } + } +} + +func (c *Checker) CheckDoubleNegation(j *lint.Job) { + fn := func(node ast.Node) bool { + unary1, ok := node.(*ast.UnaryExpr) + if !ok { + return true + } + unary2, ok := unary1.X.(*ast.UnaryExpr) + if !ok { + return true + } + if unary1.Op != token.NOT || unary2.Op != token.NOT { + return true + } + j.Errorf(unary1, "negating a boolean twice has no effect; is this a typo?") + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func hasSideEffects(node ast.Node) bool { + dynamic := false + ast.Inspect(node, func(node ast.Node) bool { + switch node := node.(type) { + case *ast.CallExpr: + dynamic = true + return false + case *ast.UnaryExpr: + if node.Op == token.ARROW { + dynamic = true + return false + } + } + return true + }) + return dynamic +} + +func (c *Checker) CheckRepeatedIfElse(j *lint.Job) { + seen := map[ast.Node]bool{} + + var collectConds func(ifstmt *ast.IfStmt, inits []ast.Stmt, conds []ast.Expr) ([]ast.Stmt, []ast.Expr) + collectConds = func(ifstmt *ast.IfStmt, inits []ast.Stmt, conds []ast.Expr) ([]ast.Stmt, []ast.Expr) { + seen[ifstmt] = true + if ifstmt.Init != nil { + inits = append(inits, ifstmt.Init) + } + conds = append(conds, ifstmt.Cond) + if elsestmt, ok := ifstmt.Else.(*ast.IfStmt); ok { + return collectConds(elsestmt, inits, conds) + } + return inits, conds + } + fn := func(node ast.Node) bool { + ifstmt, ok := node.(*ast.IfStmt) + if !ok { + return true + } + if seen[ifstmt] { + return true + } + inits, conds := collectConds(ifstmt, nil, nil) + if len(inits) > 0 { + return true + } + for _, cond := range conds { + if hasSideEffects(cond) { + return true + } + } + counts := map[string]int{} + for _, cond := range conds { + s := Render(j, cond) + counts[s]++ + if counts[s] == 2 { + j.Errorf(cond, "this condition occurs multiple times in this if/else if chain") + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckSillyBitwiseOps(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + ins, ok := ins.(*ssa.BinOp) + if !ok { + continue + } + + if c, ok := ins.Y.(*ssa.Const); !ok || c.Value == nil || c.Value.Kind() != constant.Int || c.Uint64() != 0 { + continue + } + switch ins.Op { + case token.AND, token.OR, token.XOR: + default: + // we do not flag shifts because too often, x<<0 is part + // of a pattern, x<<0, x<<8, x<<16, ... + continue + } + path, _ := astutil.PathEnclosingInterval(j.File(ins), ins.Pos(), ins.Pos()) + if len(path) == 0 { + continue + } + if node, ok := path[0].(*ast.BinaryExpr); !ok || !IsZero(node.Y) { + continue + } + + switch ins.Op { + case token.AND: + j.Errorf(ins, "x & 0 always equals 0") + case token.OR, token.XOR: + j.Errorf(ins, "x %s 0 always equals x", ins.Op) + } + } + } + } +} + +func (c *Checker) CheckNonOctalFileMode(j *lint.Job) { + fn := func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + sig, ok := TypeOf(j, call.Fun).(*types.Signature) + if !ok { + return true + } + n := sig.Params().Len() + var args []int + for i := 0; i < n; i++ { + typ := sig.Params().At(i).Type() + if IsType(typ, "os.FileMode") { + args = append(args, i) + } + } + for _, i := range args { + lit, ok := call.Args[i].(*ast.BasicLit) + if !ok { + continue + } + if len(lit.Value) == 3 && + lit.Value[0] != '0' && + lit.Value[0] >= '0' && lit.Value[0] <= '7' && + lit.Value[1] >= '0' && lit.Value[1] <= '7' && + lit.Value[2] >= '0' && lit.Value[2] <= '7' { + + v, err := strconv.ParseInt(lit.Value, 10, 64) + if err != nil { + continue + } + j.Errorf(call.Args[i], "file mode '%s' evaluates to %#o; did you mean '0%s'?", lit.Value, v, lit.Value) + } + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) CheckPureFunctions(j *lint.Job) { +fnLoop: + for _, ssafn := range j.Program.InitialFunctions { + if IsInTest(j, ssafn) { + params := ssafn.Signature.Params() + for i := 0; i < params.Len(); i++ { + param := params.At(i) + if IsType(param.Type(), "*testing.B") { + // Ignore discarded pure functions in code related + // to benchmarks. Instead of matching BenchmarkFoo + // functions, we match any function accepting a + // *testing.B. Benchmarks sometimes call generic + // functions for doing the actual work, and + // checking for the parameter is a lot easier and + // faster than analyzing call trees. + continue fnLoop + } + } + } + + for _, b := range ssafn.Blocks { + for _, ins := range b.Instrs { + ins, ok := ins.(*ssa.Call) + if !ok { + continue + } + refs := ins.Referrers() + if refs == nil || len(FilterDebug(*refs)) > 0 { + continue + } + callee := ins.Common().StaticCallee() + if callee == nil { + continue + } + if c.funcDescs.Get(callee).Pure && !c.funcDescs.Get(callee).Stub { + j.Errorf(ins, "%s is a pure function but its return value is ignored", callee.Name()) + continue + } + } + } + } +} + +func (c *Checker) isDeprecated(j *lint.Job, ident *ast.Ident) (bool, string) { + obj := ObjectOf(j, ident) + if obj.Pkg() == nil { + return false, "" + } + alt := c.deprecatedObjs[obj] + return alt != "", alt +} + +func (c *Checker) CheckDeprecated(j *lint.Job) { + // Selectors can appear outside of function literals, e.g. when + // declaring package level variables. + + var ssafn *ssa.Function + stack := 0 + fn := func(node ast.Node) bool { + if node == nil { + stack-- + } else { + stack++ + } + if stack == 1 { + ssafn = nil + } + if fn, ok := node.(*ast.FuncDecl); ok { + ssafn = j.Program.SSA.FuncValue(ObjectOf(j, fn.Name).(*types.Func)) + } + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + + obj := ObjectOf(j, sel.Sel) + if obj.Pkg() == nil { + return true + } + nodePkg := j.NodePackage(node).Pkg + if nodePkg == obj.Pkg() || obj.Pkg().Path()+"_test" == nodePkg.Path() { + // Don't flag stuff in our own package + return true + } + if ok, alt := c.isDeprecated(j, sel.Sel); ok { + // Look for the first available alternative, not the first + // version something was deprecated in. If a function was + // deprecated in Go 1.6, an alternative has been available + // already in 1.0, and we're targeting 1.2, it still + // makes sense to use the alternative from 1.0, to be + // future-proof. + minVersion := deprecated.Stdlib[SelectorName(j, sel)].AlternativeAvailableSince + if !IsGoVersion(j, minVersion) { + return true + } + + if ssafn != nil { + if _, ok := c.deprecatedObjs[ssafn.Object()]; ok { + // functions that are deprecated may use deprecated + // symbols + return true + } + } + j.Errorf(sel, "%s is deprecated: %s", Render(j, sel), alt) + return true + } + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} + +func (c *Checker) callChecker(rules map[string]CallCheck) func(j *lint.Job) { + return func(j *lint.Job) { + c.checkCalls(j, rules) + } +} + +func (c *Checker) checkCalls(j *lint.Job, rules map[string]CallCheck) { + for _, ssafn := range j.Program.InitialFunctions { + node := c.funcDescs.CallGraph.CreateNode(ssafn) + for _, edge := range node.Out { + callee := edge.Callee.Func + obj, ok := callee.Object().(*types.Func) + if !ok { + continue + } + + r, ok := rules[obj.FullName()] + if !ok { + continue + } + var args []*Argument + ssaargs := edge.Site.Common().Args + if callee.Signature.Recv() != nil { + ssaargs = ssaargs[1:] + } + for _, arg := range ssaargs { + if iarg, ok := arg.(*ssa.MakeInterface); ok { + arg = iarg.X + } + vr := c.funcDescs.Get(edge.Site.Parent()).Ranges[arg] + args = append(args, &Argument{Value: Value{arg, vr}}) + } + call := &Call{ + Job: j, + Instr: edge.Site, + Args: args, + Checker: c, + Parent: edge.Site.Parent(), + } + r(call) + for idx, arg := range call.Args { + _ = idx + for _, e := range arg.invalids { + // path, _ := astutil.PathEnclosingInterval(f.File, edge.Site.Pos(), edge.Site.Pos()) + // if len(path) < 2 { + // continue + // } + // astcall, ok := path[0].(*ast.CallExpr) + // if !ok { + // continue + // } + // j.Errorf(astcall.Args[idx], "%s", e) + + j.Errorf(edge.Site, "%s", e) + } + } + for _, e := range call.invalids { + j.Errorf(call.Instr.Common(), "%s", e) + } + } + } +} + +func unwrapFunction(val ssa.Value) *ssa.Function { + switch val := val.(type) { + case *ssa.Function: + return val + case *ssa.MakeClosure: + return val.Fn.(*ssa.Function) + default: + return nil + } +} + +func shortCallName(call *ssa.CallCommon) string { + if call.IsInvoke() { + return "" + } + switch v := call.Value.(type) { + case *ssa.Function: + fn, ok := v.Object().(*types.Func) + if !ok { + return "" + } + return fn.Name() + case *ssa.Builtin: + return v.Name() + } + return "" +} + +func hasCallTo(block *ssa.BasicBlock, name string) bool { + for _, ins := range block.Instrs { + call, ok := ins.(*ssa.Call) + if !ok { + continue + } + if IsCallTo(call.Common(), name) { + return true + } + } + return false +} + +func (c *Checker) CheckWriterBufferModified(j *lint.Job) { + // TODO(dh): this might be a good candidate for taint analysis. + // Taint the argument as MUST_NOT_MODIFY, then propagate that + // through functions like bytes.Split + + for _, ssafn := range j.Program.InitialFunctions { + sig := ssafn.Signature + if ssafn.Name() != "Write" || sig.Recv() == nil || sig.Params().Len() != 1 || sig.Results().Len() != 2 { + continue + } + tArg, ok := sig.Params().At(0).Type().(*types.Slice) + if !ok { + continue + } + if basic, ok := tArg.Elem().(*types.Basic); !ok || basic.Kind() != types.Byte { + continue + } + if basic, ok := sig.Results().At(0).Type().(*types.Basic); !ok || basic.Kind() != types.Int { + continue + } + if named, ok := sig.Results().At(1).Type().(*types.Named); !ok || !IsType(named, "error") { + continue + } + + for _, block := range ssafn.Blocks { + for _, ins := range block.Instrs { + switch ins := ins.(type) { + case *ssa.Store: + addr, ok := ins.Addr.(*ssa.IndexAddr) + if !ok { + continue + } + if addr.X != ssafn.Params[1] { + continue + } + j.Errorf(ins, "io.Writer.Write must not modify the provided buffer, not even temporarily") + case *ssa.Call: + if !IsCallTo(ins.Common(), "append") { + continue + } + if ins.Common().Args[0] != ssafn.Params[1] { + continue + } + j.Errorf(ins, "io.Writer.Write must not modify the provided buffer, not even temporarily") + } + } + } + } +} + +func loopedRegexp(name string) CallCheck { + return func(call *Call) { + if len(extractConsts(call.Args[0].Value.Value)) == 0 { + return + } + if !call.Checker.isInLoop(call.Instr.Block()) { + return + } + call.Invalid(fmt.Sprintf("calling %s in a loop has poor performance, consider using regexp.Compile", name)) + } +} + +func (c *Checker) CheckEmptyBranch(j *lint.Job) { + for _, ssafn := range j.Program.InitialFunctions { + if ssafn.Syntax() == nil { + continue + } + if IsGenerated(j.File(ssafn.Syntax())) { + continue + } + if IsExample(ssafn) { + continue + } + fn := func(node ast.Node) bool { + ifstmt, ok := node.(*ast.IfStmt) + if !ok { + return true + } + if ifstmt.Else != nil { + b, ok := ifstmt.Else.(*ast.BlockStmt) + if !ok || len(b.List) != 0 { + return true + } + j.Errorf(ifstmt.Else, "empty branch") + } + if len(ifstmt.Body.List) != 0 { + return true + } + j.Errorf(ifstmt, "empty branch") + return true + } + Inspect(ssafn.Syntax(), fn) + } +} + +func (c *Checker) CheckMapBytesKey(j *lint.Job) { + for _, fn := range j.Program.InitialFunctions { + for _, b := range fn.Blocks { + insLoop: + for _, ins := range b.Instrs { + // find []byte -> string conversions + conv, ok := ins.(*ssa.Convert) + if !ok || conv.Type() != types.Universe.Lookup("string").Type() { + continue + } + if s, ok := conv.X.Type().(*types.Slice); !ok || s.Elem() != types.Universe.Lookup("byte").Type() { + continue + } + refs := conv.Referrers() + // need at least two (DebugRef) references: the + // conversion and the *ast.Ident + if refs == nil || len(*refs) < 2 { + continue + } + ident := false + // skip first reference, that's the conversion itself + for _, ref := range (*refs)[1:] { + switch ref := ref.(type) { + case *ssa.DebugRef: + if _, ok := ref.Expr.(*ast.Ident); !ok { + // the string seems to be used somewhere + // unexpected; the default branch should + // catch this already, but be safe + continue insLoop + } else { + ident = true + } + case *ssa.Lookup: + default: + // the string is used somewhere else than a + // map lookup + continue insLoop + } + } + + // the result of the conversion wasn't assigned to an + // identifier + if !ident { + continue + } + j.Errorf(conv, "m[string(key)] would be more efficient than k := string(key); m[k]") + } + } + } +} + +func (c *Checker) CheckRangeStringRunes(j *lint.Job) { + sharedcheck.CheckRangeStringRunes(j) +} + +func (c *Checker) CheckSelfAssignment(j *lint.Job) { + fn := func(node ast.Node) bool { + assign, ok := node.(*ast.AssignStmt) + if !ok { + return true + } + if assign.Tok != token.ASSIGN || len(assign.Lhs) != len(assign.Rhs) { + return true + } + for i, stmt := range assign.Lhs { + rlh := Render(j, stmt) + rrh := Render(j, assign.Rhs[i]) + if rlh == rrh { + j.Errorf(assign, "self-assignment of %s to %s", rrh, rlh) + } + } + return true + } + for _, f := range c.filterGenerated(j.Program.Files) { + ast.Inspect(f, fn) + } +} + +func buildTagsIdentical(s1, s2 []string) bool { + if len(s1) != len(s2) { + return false + } + s1s := make([]string, len(s1)) + copy(s1s, s1) + sort.Strings(s1s) + s2s := make([]string, len(s2)) + copy(s2s, s2) + sort.Strings(s2s) + for i, s := range s1s { + if s != s2s[i] { + return false + } + } + return true +} + +func (c *Checker) CheckDuplicateBuildConstraints(job *lint.Job) { + for _, f := range c.filterGenerated(job.Program.Files) { + constraints := buildTags(f) + for i, constraint1 := range constraints { + for j, constraint2 := range constraints { + if i >= j { + continue + } + if buildTagsIdentical(constraint1, constraint2) { + job.Errorf(f, "identical build constraints %q and %q", + strings.Join(constraint1, " "), + strings.Join(constraint2, " ")) + } + } + } + } +} + +func (c *Checker) CheckSillyRegexp(j *lint.Job) { + // We could use the rule checking engine for this, but the + // arguments aren't really invalid. + for _, fn := range j.Program.InitialFunctions { + for _, b := range fn.Blocks { + for _, ins := range b.Instrs { + call, ok := ins.(*ssa.Call) + if !ok { + continue + } + switch CallName(call.Common()) { + case "regexp.MustCompile", "regexp.Compile", "regexp.Match", "regexp.MatchReader", "regexp.MatchString": + default: + continue + } + c, ok := call.Common().Args[0].(*ssa.Const) + if !ok { + continue + } + s := constant.StringVal(c.Value) + re, err := syntax.Parse(s, 0) + if err != nil { + continue + } + if re.Op != syntax.OpLiteral && re.Op != syntax.OpEmptyMatch { + continue + } + j.Errorf(call, "regular expression does not contain any meta characters") + } + } + } +} + +func (c *Checker) CheckMissingEnumTypesInDeclaration(j *lint.Job) { + fn := func(node ast.Node) bool { + decl, ok := node.(*ast.GenDecl) + if !ok { + return true + } + if !decl.Lparen.IsValid() { + // not a parenthesised gendecl + // + // TODO(dh): do we need this check, considering we require + // decl.Specs to contain 2+ elements? + return true + } + if decl.Tok != token.CONST { + return true + } + if len(decl.Specs) < 2 { + return true + } + if decl.Specs[0].(*ast.ValueSpec).Type == nil { + // first constant doesn't have a type + return true + } + for i, spec := range decl.Specs { + spec := spec.(*ast.ValueSpec) + if len(spec.Names) != 1 || len(spec.Values) != 1 { + return true + } + switch v := spec.Values[0].(type) { + case *ast.BasicLit: + case *ast.UnaryExpr: + if _, ok := v.X.(*ast.BasicLit); !ok { + return true + } + default: + // if it's not a literal it might be typed, such as + // time.Microsecond = 1000 * Nanosecond + return true + } + if i == 0 { + continue + } + if spec.Type != nil { + return true + } + } + j.Errorf(decl, "only the first constant has an explicit type") + return true + } + for _, f := range j.Program.Files { + ast.Inspect(f, fn) + } +} diff --git a/vendor/honnef.co/go/tools/staticcheck/rules.go b/vendor/honnef.co/go/tools/staticcheck/rules.go new file mode 100644 index 00000000000..d6af573c218 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/rules.go @@ -0,0 +1,322 @@ +package staticcheck + +import ( + "fmt" + "go/constant" + "go/types" + "net" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "time" + "unicode/utf8" + + "honnef.co/go/tools/lint" + . "honnef.co/go/tools/lint/lintdsl" + "honnef.co/go/tools/ssa" + "honnef.co/go/tools/staticcheck/vrp" +) + +const ( + MsgInvalidHostPort = "invalid port or service name in host:port pair" + MsgInvalidUTF8 = "argument is not a valid UTF-8 encoded string" + MsgNonUniqueCutset = "cutset contains duplicate characters" +) + +type Call struct { + Job *lint.Job + Instr ssa.CallInstruction + Args []*Argument + + Checker *Checker + Parent *ssa.Function + + invalids []string +} + +func (c *Call) Invalid(msg string) { + c.invalids = append(c.invalids, msg) +} + +type Argument struct { + Value Value + invalids []string +} + +func (arg *Argument) Invalid(msg string) { + arg.invalids = append(arg.invalids, msg) +} + +type Value struct { + Value ssa.Value + Range vrp.Range +} + +type CallCheck func(call *Call) + +func extractConsts(v ssa.Value) []*ssa.Const { + switch v := v.(type) { + case *ssa.Const: + return []*ssa.Const{v} + case *ssa.MakeInterface: + return extractConsts(v.X) + default: + return nil + } +} + +func ValidateRegexp(v Value) error { + for _, c := range extractConsts(v.Value) { + if c.Value == nil { + continue + } + if c.Value.Kind() != constant.String { + continue + } + s := constant.StringVal(c.Value) + if _, err := regexp.Compile(s); err != nil { + return err + } + } + return nil +} + +func ValidateTimeLayout(v Value) error { + for _, c := range extractConsts(v.Value) { + if c.Value == nil { + continue + } + if c.Value.Kind() != constant.String { + continue + } + s := constant.StringVal(c.Value) + s = strings.Replace(s, "_", " ", -1) + s = strings.Replace(s, "Z", "-", -1) + _, err := time.Parse(s, s) + if err != nil { + return err + } + } + return nil +} + +func ValidateURL(v Value) error { + for _, c := range extractConsts(v.Value) { + if c.Value == nil { + continue + } + if c.Value.Kind() != constant.String { + continue + } + s := constant.StringVal(c.Value) + _, err := url.Parse(s) + if err != nil { + return fmt.Errorf("%q is not a valid URL: %s", s, err) + } + } + return nil +} + +func IntValue(v Value, z vrp.Z) bool { + r, ok := v.Range.(vrp.IntInterval) + if !ok || !r.IsKnown() { + return false + } + if r.Lower != r.Upper { + return false + } + if r.Lower.Cmp(z) == 0 { + return true + } + return false +} + +func InvalidUTF8(v Value) bool { + for _, c := range extractConsts(v.Value) { + if c.Value == nil { + continue + } + if c.Value.Kind() != constant.String { + continue + } + s := constant.StringVal(c.Value) + if !utf8.ValidString(s) { + return true + } + } + return false +} + +func UnbufferedChannel(v Value) bool { + r, ok := v.Range.(vrp.ChannelInterval) + if !ok || !r.IsKnown() { + return false + } + if r.Size.Lower.Cmp(vrp.NewZ(0)) == 0 && + r.Size.Upper.Cmp(vrp.NewZ(0)) == 0 { + return true + } + return false +} + +func Pointer(v Value) bool { + switch v.Value.Type().Underlying().(type) { + case *types.Pointer, *types.Interface: + return true + } + return false +} + +func ConvertedFromInt(v Value) bool { + conv, ok := v.Value.(*ssa.Convert) + if !ok { + return false + } + b, ok := conv.X.Type().Underlying().(*types.Basic) + if !ok { + return false + } + if (b.Info() & types.IsInteger) == 0 { + return false + } + return true +} + +func validEncodingBinaryType(j *lint.Job, typ types.Type) bool { + typ = typ.Underlying() + switch typ := typ.(type) { + case *types.Basic: + switch typ.Kind() { + case types.Uint8, types.Uint16, types.Uint32, types.Uint64, + types.Int8, types.Int16, types.Int32, types.Int64, + types.Float32, types.Float64, types.Complex64, types.Complex128, types.Invalid: + return true + case types.Bool: + return IsGoVersion(j, 8) + } + return false + case *types.Struct: + n := typ.NumFields() + for i := 0; i < n; i++ { + if !validEncodingBinaryType(j, typ.Field(i).Type()) { + return false + } + } + return true + case *types.Array: + return validEncodingBinaryType(j, typ.Elem()) + case *types.Interface: + // we can't determine if it's a valid type or not + return true + } + return false +} + +func CanBinaryMarshal(j *lint.Job, v Value) bool { + typ := v.Value.Type().Underlying() + if ttyp, ok := typ.(*types.Pointer); ok { + typ = ttyp.Elem().Underlying() + } + if ttyp, ok := typ.(interface { + Elem() types.Type + }); ok { + if _, ok := ttyp.(*types.Pointer); !ok { + typ = ttyp.Elem() + } + } + + return validEncodingBinaryType(j, typ) +} + +func RepeatZeroTimes(name string, arg int) CallCheck { + return func(call *Call) { + arg := call.Args[arg] + if IntValue(arg.Value, vrp.NewZ(0)) { + arg.Invalid(fmt.Sprintf("calling %s with n == 0 will return no results, did you mean -1?", name)) + } + } +} + +func validateServiceName(s string) bool { + if len(s) < 1 || len(s) > 15 { + return false + } + if s[0] == '-' || s[len(s)-1] == '-' { + return false + } + if strings.Contains(s, "--") { + return false + } + hasLetter := false + for _, r := range s { + if (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') { + hasLetter = true + continue + } + if r >= '0' && r <= '9' { + continue + } + return false + } + return hasLetter +} + +func validatePort(s string) bool { + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return validateServiceName(s) + } + return n >= 0 && n <= 65535 +} + +func ValidHostPort(v Value) bool { + for _, k := range extractConsts(v.Value) { + if k.Value == nil { + continue + } + if k.Value.Kind() != constant.String { + continue + } + s := constant.StringVal(k.Value) + _, port, err := net.SplitHostPort(s) + if err != nil { + return false + } + // TODO(dh): check hostname + if !validatePort(port) { + return false + } + } + return true +} + +// ConvertedFrom reports whether value v was converted from type typ. +func ConvertedFrom(v Value, typ string) bool { + change, ok := v.Value.(*ssa.ChangeType) + return ok && IsType(change.X.Type(), typ) +} + +func UniqueStringCutset(v Value) bool { + for _, c := range extractConsts(v.Value) { + if c.Value == nil { + continue + } + if c.Value.Kind() != constant.String { + continue + } + s := constant.StringVal(c.Value) + rs := runeSlice(s) + if len(rs) < 2 { + continue + } + sort.Sort(rs) + for i, r := range rs[1:] { + if rs[i] == r { + return false + } + } + } + return true +} diff --git a/vendor/honnef.co/go/tools/staticcheck/vrp/channel.go b/vendor/honnef.co/go/tools/staticcheck/vrp/channel.go new file mode 100644 index 00000000000..0ef73787ba3 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/vrp/channel.go @@ -0,0 +1,73 @@ +package vrp + +import ( + "fmt" + + "honnef.co/go/tools/ssa" +) + +type ChannelInterval struct { + Size IntInterval +} + +func (c ChannelInterval) Union(other Range) Range { + i, ok := other.(ChannelInterval) + if !ok { + i = ChannelInterval{EmptyIntInterval} + } + if c.Size.Empty() || !c.Size.IsKnown() { + return i + } + if i.Size.Empty() || !i.Size.IsKnown() { + return c + } + return ChannelInterval{ + Size: c.Size.Union(i.Size).(IntInterval), + } +} + +func (c ChannelInterval) String() string { + return c.Size.String() +} + +func (c ChannelInterval) IsKnown() bool { + return c.Size.IsKnown() +} + +type MakeChannelConstraint struct { + aConstraint + Buffer ssa.Value +} +type ChannelChangeTypeConstraint struct { + aConstraint + X ssa.Value +} + +func NewMakeChannelConstraint(buffer, y ssa.Value) Constraint { + return &MakeChannelConstraint{NewConstraint(y), buffer} +} +func NewChannelChangeTypeConstraint(x, y ssa.Value) Constraint { + return &ChannelChangeTypeConstraint{NewConstraint(y), x} +} + +func (c *MakeChannelConstraint) Operands() []ssa.Value { return []ssa.Value{c.Buffer} } +func (c *ChannelChangeTypeConstraint) Operands() []ssa.Value { return []ssa.Value{c.X} } + +func (c *MakeChannelConstraint) String() string { + return fmt.Sprintf("%s = make(chan, %s)", c.Y().Name(), c.Buffer.Name()) +} +func (c *ChannelChangeTypeConstraint) String() string { + return fmt.Sprintf("%s = changetype(%s)", c.Y().Name(), c.X.Name()) +} + +func (c *MakeChannelConstraint) Eval(g *Graph) Range { + i, ok := g.Range(c.Buffer).(IntInterval) + if !ok { + return ChannelInterval{NewIntInterval(NewZ(0), PInfinity)} + } + if i.Lower.Sign() == -1 { + i.Lower = NewZ(0) + } + return ChannelInterval{i} +} +func (c *ChannelChangeTypeConstraint) Eval(g *Graph) Range { return g.Range(c.X) } diff --git a/vendor/honnef.co/go/tools/staticcheck/vrp/int.go b/vendor/honnef.co/go/tools/staticcheck/vrp/int.go new file mode 100644 index 00000000000..926bb7af3d6 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/vrp/int.go @@ -0,0 +1,476 @@ +package vrp + +import ( + "fmt" + "go/token" + "go/types" + "math/big" + + "honnef.co/go/tools/ssa" +) + +type Zs []Z + +func (zs Zs) Len() int { + return len(zs) +} + +func (zs Zs) Less(i int, j int) bool { + return zs[i].Cmp(zs[j]) == -1 +} + +func (zs Zs) Swap(i int, j int) { + zs[i], zs[j] = zs[j], zs[i] +} + +type Z struct { + infinity int8 + integer *big.Int +} + +func NewZ(n int64) Z { + return NewBigZ(big.NewInt(n)) +} + +func NewBigZ(n *big.Int) Z { + return Z{integer: n} +} + +func (z1 Z) Infinite() bool { + return z1.infinity != 0 +} + +func (z1 Z) Add(z2 Z) Z { + if z2.Sign() == -1 { + return z1.Sub(z2.Negate()) + } + if z1 == NInfinity { + return NInfinity + } + if z1 == PInfinity { + return PInfinity + } + if z2 == PInfinity { + return PInfinity + } + + if !z1.Infinite() && !z2.Infinite() { + n := &big.Int{} + n.Add(z1.integer, z2.integer) + return NewBigZ(n) + } + + panic(fmt.Sprintf("%s + %s is not defined", z1, z2)) +} + +func (z1 Z) Sub(z2 Z) Z { + if z2.Sign() == -1 { + return z1.Add(z2.Negate()) + } + if !z1.Infinite() && !z2.Infinite() { + n := &big.Int{} + n.Sub(z1.integer, z2.integer) + return NewBigZ(n) + } + + if z1 != PInfinity && z2 == PInfinity { + return NInfinity + } + if z1.Infinite() && !z2.Infinite() { + return Z{infinity: z1.infinity} + } + if z1 == PInfinity && z2 == PInfinity { + return PInfinity + } + panic(fmt.Sprintf("%s - %s is not defined", z1, z2)) +} + +func (z1 Z) Mul(z2 Z) Z { + if (z1.integer != nil && z1.integer.Sign() == 0) || + (z2.integer != nil && z2.integer.Sign() == 0) { + return NewBigZ(&big.Int{}) + } + + if z1.infinity != 0 || z2.infinity != 0 { + return Z{infinity: int8(z1.Sign() * z2.Sign())} + } + + n := &big.Int{} + n.Mul(z1.integer, z2.integer) + return NewBigZ(n) +} + +func (z1 Z) Negate() Z { + if z1.infinity == 1 { + return NInfinity + } + if z1.infinity == -1 { + return PInfinity + } + n := &big.Int{} + n.Neg(z1.integer) + return NewBigZ(n) +} + +func (z1 Z) Sign() int { + if z1.infinity != 0 { + return int(z1.infinity) + } + return z1.integer.Sign() +} + +func (z1 Z) String() string { + if z1 == NInfinity { + return "-∞" + } + if z1 == PInfinity { + return "∞" + } + return fmt.Sprintf("%d", z1.integer) +} + +func (z1 Z) Cmp(z2 Z) int { + if z1.infinity == z2.infinity && z1.infinity != 0 { + return 0 + } + if z1 == PInfinity { + return 1 + } + if z1 == NInfinity { + return -1 + } + if z2 == NInfinity { + return 1 + } + if z2 == PInfinity { + return -1 + } + return z1.integer.Cmp(z2.integer) +} + +func MaxZ(zs ...Z) Z { + if len(zs) == 0 { + panic("Max called with no arguments") + } + if len(zs) == 1 { + return zs[0] + } + ret := zs[0] + for _, z := range zs[1:] { + if z.Cmp(ret) == 1 { + ret = z + } + } + return ret +} + +func MinZ(zs ...Z) Z { + if len(zs) == 0 { + panic("Min called with no arguments") + } + if len(zs) == 1 { + return zs[0] + } + ret := zs[0] + for _, z := range zs[1:] { + if z.Cmp(ret) == -1 { + ret = z + } + } + return ret +} + +var NInfinity = Z{infinity: -1} +var PInfinity = Z{infinity: 1} +var EmptyIntInterval = IntInterval{true, PInfinity, NInfinity} + +func InfinityFor(v ssa.Value) IntInterval { + if b, ok := v.Type().Underlying().(*types.Basic); ok { + if (b.Info() & types.IsUnsigned) != 0 { + return NewIntInterval(NewZ(0), PInfinity) + } + } + return NewIntInterval(NInfinity, PInfinity) +} + +type IntInterval struct { + known bool + Lower Z + Upper Z +} + +func NewIntInterval(l, u Z) IntInterval { + if u.Cmp(l) == -1 { + return EmptyIntInterval + } + return IntInterval{known: true, Lower: l, Upper: u} +} + +func (i IntInterval) IsKnown() bool { + return i.known +} + +func (i IntInterval) Empty() bool { + return i.Lower == PInfinity && i.Upper == NInfinity +} + +func (i IntInterval) IsMaxRange() bool { + return i.Lower == NInfinity && i.Upper == PInfinity +} + +func (i1 IntInterval) Intersection(i2 IntInterval) IntInterval { + if !i1.IsKnown() { + return i2 + } + if !i2.IsKnown() { + return i1 + } + if i1.Empty() || i2.Empty() { + return EmptyIntInterval + } + i3 := NewIntInterval(MaxZ(i1.Lower, i2.Lower), MinZ(i1.Upper, i2.Upper)) + if i3.Lower.Cmp(i3.Upper) == 1 { + return EmptyIntInterval + } + return i3 +} + +func (i1 IntInterval) Union(other Range) Range { + i2, ok := other.(IntInterval) + if !ok { + i2 = EmptyIntInterval + } + if i1.Empty() || !i1.IsKnown() { + return i2 + } + if i2.Empty() || !i2.IsKnown() { + return i1 + } + return NewIntInterval(MinZ(i1.Lower, i2.Lower), MaxZ(i1.Upper, i2.Upper)) +} + +func (i1 IntInterval) Add(i2 IntInterval) IntInterval { + if i1.Empty() || i2.Empty() { + return EmptyIntInterval + } + l1, u1, l2, u2 := i1.Lower, i1.Upper, i2.Lower, i2.Upper + return NewIntInterval(l1.Add(l2), u1.Add(u2)) +} + +func (i1 IntInterval) Sub(i2 IntInterval) IntInterval { + if i1.Empty() || i2.Empty() { + return EmptyIntInterval + } + l1, u1, l2, u2 := i1.Lower, i1.Upper, i2.Lower, i2.Upper + return NewIntInterval(l1.Sub(u2), u1.Sub(l2)) +} + +func (i1 IntInterval) Mul(i2 IntInterval) IntInterval { + if i1.Empty() || i2.Empty() { + return EmptyIntInterval + } + x1, x2 := i1.Lower, i1.Upper + y1, y2 := i2.Lower, i2.Upper + return NewIntInterval( + MinZ(x1.Mul(y1), x1.Mul(y2), x2.Mul(y1), x2.Mul(y2)), + MaxZ(x1.Mul(y1), x1.Mul(y2), x2.Mul(y1), x2.Mul(y2)), + ) +} + +func (i1 IntInterval) String() string { + if !i1.IsKnown() { + return "[⊥, ⊥]" + } + if i1.Empty() { + return "{}" + } + return fmt.Sprintf("[%s, %s]", i1.Lower, i1.Upper) +} + +type IntArithmeticConstraint struct { + aConstraint + A ssa.Value + B ssa.Value + Op token.Token + Fn func(IntInterval, IntInterval) IntInterval +} + +type IntAddConstraint struct{ *IntArithmeticConstraint } +type IntSubConstraint struct{ *IntArithmeticConstraint } +type IntMulConstraint struct{ *IntArithmeticConstraint } + +type IntConversionConstraint struct { + aConstraint + X ssa.Value +} + +type IntIntersectionConstraint struct { + aConstraint + ranges Ranges + A ssa.Value + B ssa.Value + Op token.Token + I IntInterval + resolved bool +} + +type IntIntervalConstraint struct { + aConstraint + I IntInterval +} + +func NewIntArithmeticConstraint(a, b, y ssa.Value, op token.Token, fn func(IntInterval, IntInterval) IntInterval) *IntArithmeticConstraint { + return &IntArithmeticConstraint{NewConstraint(y), a, b, op, fn} +} +func NewIntAddConstraint(a, b, y ssa.Value) Constraint { + return &IntAddConstraint{NewIntArithmeticConstraint(a, b, y, token.ADD, IntInterval.Add)} +} +func NewIntSubConstraint(a, b, y ssa.Value) Constraint { + return &IntSubConstraint{NewIntArithmeticConstraint(a, b, y, token.SUB, IntInterval.Sub)} +} +func NewIntMulConstraint(a, b, y ssa.Value) Constraint { + return &IntMulConstraint{NewIntArithmeticConstraint(a, b, y, token.MUL, IntInterval.Mul)} +} +func NewIntConversionConstraint(x, y ssa.Value) Constraint { + return &IntConversionConstraint{NewConstraint(y), x} +} +func NewIntIntersectionConstraint(a, b ssa.Value, op token.Token, ranges Ranges, y ssa.Value) Constraint { + return &IntIntersectionConstraint{ + aConstraint: NewConstraint(y), + ranges: ranges, + A: a, + B: b, + Op: op, + } +} +func NewIntIntervalConstraint(i IntInterval, y ssa.Value) Constraint { + return &IntIntervalConstraint{NewConstraint(y), i} +} + +func (c *IntArithmeticConstraint) Operands() []ssa.Value { return []ssa.Value{c.A, c.B} } +func (c *IntConversionConstraint) Operands() []ssa.Value { return []ssa.Value{c.X} } +func (c *IntIntersectionConstraint) Operands() []ssa.Value { return []ssa.Value{c.A} } +func (s *IntIntervalConstraint) Operands() []ssa.Value { return nil } + +func (c *IntArithmeticConstraint) String() string { + return fmt.Sprintf("%s = %s %s %s", c.Y().Name(), c.A.Name(), c.Op, c.B.Name()) +} +func (c *IntConversionConstraint) String() string { + return fmt.Sprintf("%s = %s(%s)", c.Y().Name(), c.Y().Type(), c.X.Name()) +} +func (c *IntIntersectionConstraint) String() string { + return fmt.Sprintf("%s = %s %s %s (%t branch)", c.Y().Name(), c.A.Name(), c.Op, c.B.Name(), c.Y().(*ssa.Sigma).Branch) +} +func (c *IntIntervalConstraint) String() string { return fmt.Sprintf("%s = %s", c.Y().Name(), c.I) } + +func (c *IntArithmeticConstraint) Eval(g *Graph) Range { + i1, i2 := g.Range(c.A).(IntInterval), g.Range(c.B).(IntInterval) + if !i1.IsKnown() || !i2.IsKnown() { + return IntInterval{} + } + return c.Fn(i1, i2) +} +func (c *IntConversionConstraint) Eval(g *Graph) Range { + s := &types.StdSizes{ + // XXX is it okay to assume the largest word size, or do we + // need to be platform specific? + WordSize: 8, + MaxAlign: 1, + } + fromI := g.Range(c.X).(IntInterval) + toI := g.Range(c.Y()).(IntInterval) + fromT := c.X.Type().Underlying().(*types.Basic) + toT := c.Y().Type().Underlying().(*types.Basic) + fromB := s.Sizeof(c.X.Type()) + toB := s.Sizeof(c.Y().Type()) + + if !fromI.IsKnown() { + return toI + } + if !toI.IsKnown() { + return fromI + } + + // uint -> sint/uint, M > N: [max(0, l1), min(2**N-1, u2)] + if (fromT.Info()&types.IsUnsigned != 0) && + toB > fromB { + + n := big.NewInt(1) + n.Lsh(n, uint(fromB*8)) + n.Sub(n, big.NewInt(1)) + return NewIntInterval( + MaxZ(NewZ(0), fromI.Lower), + MinZ(NewBigZ(n), toI.Upper), + ) + } + + // sint -> sint, M > N; [max(-∞, l1), min(2**N-1, u2)] + if (fromT.Info()&types.IsUnsigned == 0) && + (toT.Info()&types.IsUnsigned == 0) && + toB > fromB { + + n := big.NewInt(1) + n.Lsh(n, uint(fromB*8)) + n.Sub(n, big.NewInt(1)) + return NewIntInterval( + MaxZ(NInfinity, fromI.Lower), + MinZ(NewBigZ(n), toI.Upper), + ) + } + + return fromI +} +func (c *IntIntersectionConstraint) Eval(g *Graph) Range { + xi := g.Range(c.A).(IntInterval) + if !xi.IsKnown() { + return c.I + } + return xi.Intersection(c.I) +} +func (c *IntIntervalConstraint) Eval(*Graph) Range { return c.I } + +func (c *IntIntersectionConstraint) Futures() []ssa.Value { + return []ssa.Value{c.B} +} + +func (c *IntIntersectionConstraint) Resolve() { + r, ok := c.ranges[c.B].(IntInterval) + if !ok { + c.I = InfinityFor(c.Y()) + return + } + + switch c.Op { + case token.EQL: + c.I = r + case token.GTR: + c.I = NewIntInterval(r.Lower.Add(NewZ(1)), PInfinity) + case token.GEQ: + c.I = NewIntInterval(r.Lower, PInfinity) + case token.LSS: + // TODO(dh): do we need 0 instead of NInfinity for uints? + c.I = NewIntInterval(NInfinity, r.Upper.Sub(NewZ(1))) + case token.LEQ: + c.I = NewIntInterval(NInfinity, r.Upper) + case token.NEQ: + c.I = InfinityFor(c.Y()) + default: + panic("unsupported op " + c.Op.String()) + } +} + +func (c *IntIntersectionConstraint) IsKnown() bool { + return c.I.IsKnown() +} + +func (c *IntIntersectionConstraint) MarkUnresolved() { + c.resolved = false +} + +func (c *IntIntersectionConstraint) MarkResolved() { + c.resolved = true +} + +func (c *IntIntersectionConstraint) IsResolved() bool { + return c.resolved +} diff --git a/vendor/honnef.co/go/tools/staticcheck/vrp/slice.go b/vendor/honnef.co/go/tools/staticcheck/vrp/slice.go new file mode 100644 index 00000000000..40658dd8d86 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/vrp/slice.go @@ -0,0 +1,273 @@ +package vrp + +// TODO(dh): most of the constraints have implementations identical to +// that of strings. Consider reusing them. + +import ( + "fmt" + "go/types" + + "honnef.co/go/tools/ssa" +) + +type SliceInterval struct { + Length IntInterval +} + +func (s SliceInterval) Union(other Range) Range { + i, ok := other.(SliceInterval) + if !ok { + i = SliceInterval{EmptyIntInterval} + } + if s.Length.Empty() || !s.Length.IsKnown() { + return i + } + if i.Length.Empty() || !i.Length.IsKnown() { + return s + } + return SliceInterval{ + Length: s.Length.Union(i.Length).(IntInterval), + } +} +func (s SliceInterval) String() string { return s.Length.String() } +func (s SliceInterval) IsKnown() bool { return s.Length.IsKnown() } + +type SliceAppendConstraint struct { + aConstraint + A ssa.Value + B ssa.Value +} + +type SliceSliceConstraint struct { + aConstraint + X ssa.Value + Lower ssa.Value + Upper ssa.Value +} + +type ArraySliceConstraint struct { + aConstraint + X ssa.Value + Lower ssa.Value + Upper ssa.Value +} + +type SliceIntersectionConstraint struct { + aConstraint + X ssa.Value + I IntInterval +} + +type SliceLengthConstraint struct { + aConstraint + X ssa.Value +} + +type MakeSliceConstraint struct { + aConstraint + Size ssa.Value +} + +type SliceIntervalConstraint struct { + aConstraint + I IntInterval +} + +func NewSliceAppendConstraint(a, b, y ssa.Value) Constraint { + return &SliceAppendConstraint{NewConstraint(y), a, b} +} +func NewSliceSliceConstraint(x, lower, upper, y ssa.Value) Constraint { + return &SliceSliceConstraint{NewConstraint(y), x, lower, upper} +} +func NewArraySliceConstraint(x, lower, upper, y ssa.Value) Constraint { + return &ArraySliceConstraint{NewConstraint(y), x, lower, upper} +} +func NewSliceIntersectionConstraint(x ssa.Value, i IntInterval, y ssa.Value) Constraint { + return &SliceIntersectionConstraint{NewConstraint(y), x, i} +} +func NewSliceLengthConstraint(x, y ssa.Value) Constraint { + return &SliceLengthConstraint{NewConstraint(y), x} +} +func NewMakeSliceConstraint(size, y ssa.Value) Constraint { + return &MakeSliceConstraint{NewConstraint(y), size} +} +func NewSliceIntervalConstraint(i IntInterval, y ssa.Value) Constraint { + return &SliceIntervalConstraint{NewConstraint(y), i} +} + +func (c *SliceAppendConstraint) Operands() []ssa.Value { return []ssa.Value{c.A, c.B} } +func (c *SliceSliceConstraint) Operands() []ssa.Value { + ops := []ssa.Value{c.X} + if c.Lower != nil { + ops = append(ops, c.Lower) + } + if c.Upper != nil { + ops = append(ops, c.Upper) + } + return ops +} +func (c *ArraySliceConstraint) Operands() []ssa.Value { + ops := []ssa.Value{c.X} + if c.Lower != nil { + ops = append(ops, c.Lower) + } + if c.Upper != nil { + ops = append(ops, c.Upper) + } + return ops +} +func (c *SliceIntersectionConstraint) Operands() []ssa.Value { return []ssa.Value{c.X} } +func (c *SliceLengthConstraint) Operands() []ssa.Value { return []ssa.Value{c.X} } +func (c *MakeSliceConstraint) Operands() []ssa.Value { return []ssa.Value{c.Size} } +func (s *SliceIntervalConstraint) Operands() []ssa.Value { return nil } + +func (c *SliceAppendConstraint) String() string { + return fmt.Sprintf("%s = append(%s, %s)", c.Y().Name(), c.A.Name(), c.B.Name()) +} +func (c *SliceSliceConstraint) String() string { + var lname, uname string + if c.Lower != nil { + lname = c.Lower.Name() + } + if c.Upper != nil { + uname = c.Upper.Name() + } + return fmt.Sprintf("%s[%s:%s]", c.X.Name(), lname, uname) +} +func (c *ArraySliceConstraint) String() string { + var lname, uname string + if c.Lower != nil { + lname = c.Lower.Name() + } + if c.Upper != nil { + uname = c.Upper.Name() + } + return fmt.Sprintf("%s[%s:%s]", c.X.Name(), lname, uname) +} +func (c *SliceIntersectionConstraint) String() string { + return fmt.Sprintf("%s = %s.%t ⊓ %s", c.Y().Name(), c.X.Name(), c.Y().(*ssa.Sigma).Branch, c.I) +} +func (c *SliceLengthConstraint) String() string { + return fmt.Sprintf("%s = len(%s)", c.Y().Name(), c.X.Name()) +} +func (c *MakeSliceConstraint) String() string { + return fmt.Sprintf("%s = make(slice, %s)", c.Y().Name(), c.Size.Name()) +} +func (c *SliceIntervalConstraint) String() string { return fmt.Sprintf("%s = %s", c.Y().Name(), c.I) } + +func (c *SliceAppendConstraint) Eval(g *Graph) Range { + l1 := g.Range(c.A).(SliceInterval).Length + var l2 IntInterval + switch r := g.Range(c.B).(type) { + case SliceInterval: + l2 = r.Length + case StringInterval: + l2 = r.Length + default: + return SliceInterval{} + } + if !l1.IsKnown() || !l2.IsKnown() { + return SliceInterval{} + } + return SliceInterval{ + Length: l1.Add(l2), + } +} +func (c *SliceSliceConstraint) Eval(g *Graph) Range { + lr := NewIntInterval(NewZ(0), NewZ(0)) + if c.Lower != nil { + lr = g.Range(c.Lower).(IntInterval) + } + ur := g.Range(c.X).(SliceInterval).Length + if c.Upper != nil { + ur = g.Range(c.Upper).(IntInterval) + } + if !lr.IsKnown() || !ur.IsKnown() { + return SliceInterval{} + } + + ls := []Z{ + ur.Lower.Sub(lr.Lower), + ur.Upper.Sub(lr.Lower), + ur.Lower.Sub(lr.Upper), + ur.Upper.Sub(lr.Upper), + } + // TODO(dh): if we don't truncate lengths to 0 we might be able to + // easily detect slices with high < low. we'd need to treat -∞ + // specially, though. + for i, l := range ls { + if l.Sign() == -1 { + ls[i] = NewZ(0) + } + } + + return SliceInterval{ + Length: NewIntInterval(MinZ(ls...), MaxZ(ls...)), + } +} +func (c *ArraySliceConstraint) Eval(g *Graph) Range { + lr := NewIntInterval(NewZ(0), NewZ(0)) + if c.Lower != nil { + lr = g.Range(c.Lower).(IntInterval) + } + var l int64 + switch typ := c.X.Type().(type) { + case *types.Array: + l = typ.Len() + case *types.Pointer: + l = typ.Elem().(*types.Array).Len() + } + ur := NewIntInterval(NewZ(l), NewZ(l)) + if c.Upper != nil { + ur = g.Range(c.Upper).(IntInterval) + } + if !lr.IsKnown() || !ur.IsKnown() { + return SliceInterval{} + } + + ls := []Z{ + ur.Lower.Sub(lr.Lower), + ur.Upper.Sub(lr.Lower), + ur.Lower.Sub(lr.Upper), + ur.Upper.Sub(lr.Upper), + } + // TODO(dh): if we don't truncate lengths to 0 we might be able to + // easily detect slices with high < low. we'd need to treat -∞ + // specially, though. + for i, l := range ls { + if l.Sign() == -1 { + ls[i] = NewZ(0) + } + } + + return SliceInterval{ + Length: NewIntInterval(MinZ(ls...), MaxZ(ls...)), + } +} +func (c *SliceIntersectionConstraint) Eval(g *Graph) Range { + xi := g.Range(c.X).(SliceInterval) + if !xi.IsKnown() { + return c.I + } + return SliceInterval{ + Length: xi.Length.Intersection(c.I), + } +} +func (c *SliceLengthConstraint) Eval(g *Graph) Range { + i := g.Range(c.X).(SliceInterval).Length + if !i.IsKnown() { + return NewIntInterval(NewZ(0), PInfinity) + } + return i +} +func (c *MakeSliceConstraint) Eval(g *Graph) Range { + i, ok := g.Range(c.Size).(IntInterval) + if !ok { + return SliceInterval{NewIntInterval(NewZ(0), PInfinity)} + } + if i.Lower.Sign() == -1 { + i.Lower = NewZ(0) + } + return SliceInterval{i} +} +func (c *SliceIntervalConstraint) Eval(*Graph) Range { return SliceInterval{c.I} } diff --git a/vendor/honnef.co/go/tools/staticcheck/vrp/string.go b/vendor/honnef.co/go/tools/staticcheck/vrp/string.go new file mode 100644 index 00000000000..e05877f9f78 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/vrp/string.go @@ -0,0 +1,258 @@ +package vrp + +import ( + "fmt" + "go/token" + "go/types" + + "honnef.co/go/tools/ssa" +) + +type StringInterval struct { + Length IntInterval +} + +func (s StringInterval) Union(other Range) Range { + i, ok := other.(StringInterval) + if !ok { + i = StringInterval{EmptyIntInterval} + } + if s.Length.Empty() || !s.Length.IsKnown() { + return i + } + if i.Length.Empty() || !i.Length.IsKnown() { + return s + } + return StringInterval{ + Length: s.Length.Union(i.Length).(IntInterval), + } +} + +func (s StringInterval) String() string { + return s.Length.String() +} + +func (s StringInterval) IsKnown() bool { + return s.Length.IsKnown() +} + +type StringSliceConstraint struct { + aConstraint + X ssa.Value + Lower ssa.Value + Upper ssa.Value +} + +type StringIntersectionConstraint struct { + aConstraint + ranges Ranges + A ssa.Value + B ssa.Value + Op token.Token + I IntInterval + resolved bool +} + +type StringConcatConstraint struct { + aConstraint + A ssa.Value + B ssa.Value +} + +type StringLengthConstraint struct { + aConstraint + X ssa.Value +} + +type StringIntervalConstraint struct { + aConstraint + I IntInterval +} + +func NewStringSliceConstraint(x, lower, upper, y ssa.Value) Constraint { + return &StringSliceConstraint{NewConstraint(y), x, lower, upper} +} +func NewStringIntersectionConstraint(a, b ssa.Value, op token.Token, ranges Ranges, y ssa.Value) Constraint { + return &StringIntersectionConstraint{ + aConstraint: NewConstraint(y), + ranges: ranges, + A: a, + B: b, + Op: op, + } +} +func NewStringConcatConstraint(a, b, y ssa.Value) Constraint { + return &StringConcatConstraint{NewConstraint(y), a, b} +} +func NewStringLengthConstraint(x ssa.Value, y ssa.Value) Constraint { + return &StringLengthConstraint{NewConstraint(y), x} +} +func NewStringIntervalConstraint(i IntInterval, y ssa.Value) Constraint { + return &StringIntervalConstraint{NewConstraint(y), i} +} + +func (c *StringSliceConstraint) Operands() []ssa.Value { + vs := []ssa.Value{c.X} + if c.Lower != nil { + vs = append(vs, c.Lower) + } + if c.Upper != nil { + vs = append(vs, c.Upper) + } + return vs +} +func (c *StringIntersectionConstraint) Operands() []ssa.Value { return []ssa.Value{c.A} } +func (c StringConcatConstraint) Operands() []ssa.Value { return []ssa.Value{c.A, c.B} } +func (c *StringLengthConstraint) Operands() []ssa.Value { return []ssa.Value{c.X} } +func (s *StringIntervalConstraint) Operands() []ssa.Value { return nil } + +func (c *StringSliceConstraint) String() string { + var lname, uname string + if c.Lower != nil { + lname = c.Lower.Name() + } + if c.Upper != nil { + uname = c.Upper.Name() + } + return fmt.Sprintf("%s[%s:%s]", c.X.Name(), lname, uname) +} +func (c *StringIntersectionConstraint) String() string { + return fmt.Sprintf("%s = %s %s %s (%t branch)", c.Y().Name(), c.A.Name(), c.Op, c.B.Name(), c.Y().(*ssa.Sigma).Branch) +} +func (c StringConcatConstraint) String() string { + return fmt.Sprintf("%s = %s + %s", c.Y().Name(), c.A.Name(), c.B.Name()) +} +func (c *StringLengthConstraint) String() string { + return fmt.Sprintf("%s = len(%s)", c.Y().Name(), c.X.Name()) +} +func (c *StringIntervalConstraint) String() string { return fmt.Sprintf("%s = %s", c.Y().Name(), c.I) } + +func (c *StringSliceConstraint) Eval(g *Graph) Range { + lr := NewIntInterval(NewZ(0), NewZ(0)) + if c.Lower != nil { + lr = g.Range(c.Lower).(IntInterval) + } + ur := g.Range(c.X).(StringInterval).Length + if c.Upper != nil { + ur = g.Range(c.Upper).(IntInterval) + } + if !lr.IsKnown() || !ur.IsKnown() { + return StringInterval{} + } + + ls := []Z{ + ur.Lower.Sub(lr.Lower), + ur.Upper.Sub(lr.Lower), + ur.Lower.Sub(lr.Upper), + ur.Upper.Sub(lr.Upper), + } + // TODO(dh): if we don't truncate lengths to 0 we might be able to + // easily detect slices with high < low. we'd need to treat -∞ + // specially, though. + for i, l := range ls { + if l.Sign() == -1 { + ls[i] = NewZ(0) + } + } + + return StringInterval{ + Length: NewIntInterval(MinZ(ls...), MaxZ(ls...)), + } +} +func (c *StringIntersectionConstraint) Eval(g *Graph) Range { + var l IntInterval + switch r := g.Range(c.A).(type) { + case StringInterval: + l = r.Length + case IntInterval: + l = r + } + + if !l.IsKnown() { + return StringInterval{c.I} + } + return StringInterval{ + Length: l.Intersection(c.I), + } +} +func (c StringConcatConstraint) Eval(g *Graph) Range { + i1, i2 := g.Range(c.A).(StringInterval), g.Range(c.B).(StringInterval) + if !i1.Length.IsKnown() || !i2.Length.IsKnown() { + return StringInterval{} + } + return StringInterval{ + Length: i1.Length.Add(i2.Length), + } +} +func (c *StringLengthConstraint) Eval(g *Graph) Range { + i := g.Range(c.X).(StringInterval).Length + if !i.IsKnown() { + return NewIntInterval(NewZ(0), PInfinity) + } + return i +} +func (c *StringIntervalConstraint) Eval(*Graph) Range { return StringInterval{c.I} } + +func (c *StringIntersectionConstraint) Futures() []ssa.Value { + return []ssa.Value{c.B} +} + +func (c *StringIntersectionConstraint) Resolve() { + if (c.A.Type().Underlying().(*types.Basic).Info() & types.IsString) != 0 { + // comparing two strings + r, ok := c.ranges[c.B].(StringInterval) + if !ok { + c.I = NewIntInterval(NewZ(0), PInfinity) + return + } + switch c.Op { + case token.EQL: + c.I = r.Length + case token.GTR, token.GEQ: + c.I = NewIntInterval(r.Length.Lower, PInfinity) + case token.LSS, token.LEQ: + c.I = NewIntInterval(NewZ(0), r.Length.Upper) + case token.NEQ: + default: + panic("unsupported op " + c.Op.String()) + } + } else { + r, ok := c.ranges[c.B].(IntInterval) + if !ok { + c.I = NewIntInterval(NewZ(0), PInfinity) + return + } + // comparing two lengths + switch c.Op { + case token.EQL: + c.I = r + case token.GTR: + c.I = NewIntInterval(r.Lower.Add(NewZ(1)), PInfinity) + case token.GEQ: + c.I = NewIntInterval(r.Lower, PInfinity) + case token.LSS: + c.I = NewIntInterval(NInfinity, r.Upper.Sub(NewZ(1))) + case token.LEQ: + c.I = NewIntInterval(NInfinity, r.Upper) + case token.NEQ: + default: + panic("unsupported op " + c.Op.String()) + } + } +} + +func (c *StringIntersectionConstraint) IsKnown() bool { + return c.I.IsKnown() +} + +func (c *StringIntersectionConstraint) MarkUnresolved() { + c.resolved = false +} + +func (c *StringIntersectionConstraint) MarkResolved() { + c.resolved = true +} + +func (c *StringIntersectionConstraint) IsResolved() bool { + return c.resolved +} diff --git a/vendor/honnef.co/go/tools/staticcheck/vrp/vrp.go b/vendor/honnef.co/go/tools/staticcheck/vrp/vrp.go new file mode 100644 index 00000000000..cb17f042ac3 --- /dev/null +++ b/vendor/honnef.co/go/tools/staticcheck/vrp/vrp.go @@ -0,0 +1,1049 @@ +package vrp + +// TODO(dh) widening and narrowing have a lot of code in common. Make +// it reusable. + +import ( + "fmt" + "go/constant" + "go/token" + "go/types" + "math/big" + "sort" + "strings" + + "honnef.co/go/tools/ssa" +) + +type Future interface { + Constraint + Futures() []ssa.Value + Resolve() + IsKnown() bool + MarkUnresolved() + MarkResolved() + IsResolved() bool +} + +type Range interface { + Union(other Range) Range + IsKnown() bool +} + +type Constraint interface { + Y() ssa.Value + isConstraint() + String() string + Eval(*Graph) Range + Operands() []ssa.Value +} + +type aConstraint struct { + y ssa.Value +} + +func NewConstraint(y ssa.Value) aConstraint { + return aConstraint{y} +} + +func (aConstraint) isConstraint() {} +func (c aConstraint) Y() ssa.Value { return c.y } + +type PhiConstraint struct { + aConstraint + Vars []ssa.Value +} + +func NewPhiConstraint(vars []ssa.Value, y ssa.Value) Constraint { + uniqm := map[ssa.Value]struct{}{} + for _, v := range vars { + uniqm[v] = struct{}{} + } + var uniq []ssa.Value + for v := range uniqm { + uniq = append(uniq, v) + } + return &PhiConstraint{ + aConstraint: NewConstraint(y), + Vars: uniq, + } +} + +func (c *PhiConstraint) Operands() []ssa.Value { + return c.Vars +} + +func (c *PhiConstraint) Eval(g *Graph) Range { + i := Range(nil) + for _, v := range c.Vars { + i = g.Range(v).Union(i) + } + return i +} + +func (c *PhiConstraint) String() string { + names := make([]string, len(c.Vars)) + for i, v := range c.Vars { + names[i] = v.Name() + } + return fmt.Sprintf("%s = φ(%s)", c.Y().Name(), strings.Join(names, ", ")) +} + +func isSupportedType(typ types.Type) bool { + switch typ := typ.Underlying().(type) { + case *types.Basic: + switch typ.Kind() { + case types.String, types.UntypedString: + return true + default: + if (typ.Info() & types.IsInteger) == 0 { + return false + } + } + case *types.Chan: + return true + case *types.Slice: + return true + default: + return false + } + return true +} + +func ConstantToZ(c constant.Value) Z { + s := constant.ToInt(c).ExactString() + n := &big.Int{} + n.SetString(s, 10) + return NewBigZ(n) +} + +func sigmaInteger(g *Graph, ins *ssa.Sigma, cond *ssa.BinOp, ops []*ssa.Value) Constraint { + op := cond.Op + if !ins.Branch { + op = (invertToken(op)) + } + + switch op { + case token.EQL, token.GTR, token.GEQ, token.LSS, token.LEQ: + default: + return nil + } + var a, b ssa.Value + if (*ops[0]) == ins.X { + a = *ops[0] + b = *ops[1] + } else { + a = *ops[1] + b = *ops[0] + op = flipToken(op) + } + return NewIntIntersectionConstraint(a, b, op, g.ranges, ins) +} + +func sigmaString(g *Graph, ins *ssa.Sigma, cond *ssa.BinOp, ops []*ssa.Value) Constraint { + op := cond.Op + if !ins.Branch { + op = (invertToken(op)) + } + + switch op { + case token.EQL, token.GTR, token.GEQ, token.LSS, token.LEQ: + default: + return nil + } + + if ((*ops[0]).Type().Underlying().(*types.Basic).Info() & types.IsString) == 0 { + var a, b ssa.Value + call, ok := (*ops[0]).(*ssa.Call) + if ok && call.Common().Args[0] == ins.X { + a = *ops[0] + b = *ops[1] + } else { + a = *ops[1] + b = *ops[0] + op = flipToken(op) + } + return NewStringIntersectionConstraint(a, b, op, g.ranges, ins) + } + var a, b ssa.Value + if (*ops[0]) == ins.X { + a = *ops[0] + b = *ops[1] + } else { + a = *ops[1] + b = *ops[0] + op = flipToken(op) + } + return NewStringIntersectionConstraint(a, b, op, g.ranges, ins) +} + +func sigmaSlice(g *Graph, ins *ssa.Sigma, cond *ssa.BinOp, ops []*ssa.Value) Constraint { + // TODO(dh) sigmaSlice and sigmaString are a lot alike. Can they + // be merged? + // + // XXX support futures + + op := cond.Op + if !ins.Branch { + op = (invertToken(op)) + } + + k, ok := (*ops[1]).(*ssa.Const) + // XXX investigate in what cases this wouldn't be a Const + // + // XXX what if left and right are swapped? + if !ok { + return nil + } + + call, ok := (*ops[0]).(*ssa.Call) + if !ok { + return nil + } + builtin, ok := call.Common().Value.(*ssa.Builtin) + if !ok { + return nil + } + if builtin.Name() != "len" { + return nil + } + callops := call.Operands(nil) + + v := ConstantToZ(k.Value) + c := NewSliceIntersectionConstraint(*callops[1], IntInterval{}, ins).(*SliceIntersectionConstraint) + switch op { + case token.EQL: + c.I = NewIntInterval(v, v) + case token.GTR, token.GEQ: + off := int64(0) + if cond.Op == token.GTR { + off = 1 + } + c.I = NewIntInterval( + v.Add(NewZ(off)), + PInfinity, + ) + case token.LSS, token.LEQ: + off := int64(0) + if cond.Op == token.LSS { + off = -1 + } + c.I = NewIntInterval( + NInfinity, + v.Add(NewZ(off)), + ) + default: + return nil + } + return c +} + +func BuildGraph(f *ssa.Function) *Graph { + g := &Graph{ + Vertices: map[interface{}]*Vertex{}, + ranges: Ranges{}, + } + + var cs []Constraint + + ops := make([]*ssa.Value, 16) + seen := map[ssa.Value]bool{} + for _, block := range f.Blocks { + for _, ins := range block.Instrs { + ops = ins.Operands(ops[:0]) + for _, op := range ops { + if c, ok := (*op).(*ssa.Const); ok { + if seen[c] { + continue + } + seen[c] = true + if c.Value == nil { + switch c.Type().Underlying().(type) { + case *types.Slice: + cs = append(cs, NewSliceIntervalConstraint(NewIntInterval(NewZ(0), NewZ(0)), c)) + } + continue + } + switch c.Value.Kind() { + case constant.Int: + v := ConstantToZ(c.Value) + cs = append(cs, NewIntIntervalConstraint(NewIntInterval(v, v), c)) + case constant.String: + s := constant.StringVal(c.Value) + n := NewZ(int64(len(s))) + cs = append(cs, NewStringIntervalConstraint(NewIntInterval(n, n), c)) + } + } + } + } + } + for _, block := range f.Blocks { + for _, ins := range block.Instrs { + switch ins := ins.(type) { + case *ssa.Convert: + switch v := ins.Type().Underlying().(type) { + case *types.Basic: + if (v.Info() & types.IsInteger) == 0 { + continue + } + cs = append(cs, NewIntConversionConstraint(ins.X, ins)) + } + case *ssa.Call: + if static := ins.Common().StaticCallee(); static != nil { + if fn, ok := static.Object().(*types.Func); ok { + switch fn.FullName() { + case "bytes.Index", "bytes.IndexAny", "bytes.IndexByte", + "bytes.IndexFunc", "bytes.IndexRune", "bytes.LastIndex", + "bytes.LastIndexAny", "bytes.LastIndexByte", "bytes.LastIndexFunc", + "strings.Index", "strings.IndexAny", "strings.IndexByte", + "strings.IndexFunc", "strings.IndexRune", "strings.LastIndex", + "strings.LastIndexAny", "strings.LastIndexByte", "strings.LastIndexFunc": + // TODO(dh): instead of limiting by +∞, + // limit by the upper bound of the passed + // string + cs = append(cs, NewIntIntervalConstraint(NewIntInterval(NewZ(-1), PInfinity), ins)) + case "bytes.Title", "bytes.ToLower", "bytes.ToTitle", "bytes.ToUpper", + "strings.Title", "strings.ToLower", "strings.ToTitle", "strings.ToUpper": + cs = append(cs, NewCopyConstraint(ins.Common().Args[0], ins)) + case "bytes.ToLowerSpecial", "bytes.ToTitleSpecial", "bytes.ToUpperSpecial", + "strings.ToLowerSpecial", "strings.ToTitleSpecial", "strings.ToUpperSpecial": + cs = append(cs, NewCopyConstraint(ins.Common().Args[1], ins)) + case "bytes.Compare", "strings.Compare": + cs = append(cs, NewIntIntervalConstraint(NewIntInterval(NewZ(-1), NewZ(1)), ins)) + case "bytes.Count", "strings.Count": + // TODO(dh): instead of limiting by +∞, + // limit by the upper bound of the passed + // string. + cs = append(cs, NewIntIntervalConstraint(NewIntInterval(NewZ(0), PInfinity), ins)) + case "bytes.Map", "bytes.TrimFunc", "bytes.TrimLeft", "bytes.TrimLeftFunc", + "bytes.TrimRight", "bytes.TrimRightFunc", "bytes.TrimSpace", + "strings.Map", "strings.TrimFunc", "strings.TrimLeft", "strings.TrimLeftFunc", + "strings.TrimRight", "strings.TrimRightFunc", "strings.TrimSpace": + // TODO(dh): lower = 0, upper = upper of passed string + case "bytes.TrimPrefix", "bytes.TrimSuffix", + "strings.TrimPrefix", "strings.TrimSuffix": + // TODO(dh) range between "unmodified" and len(cutset) removed + case "(*bytes.Buffer).Cap", "(*bytes.Buffer).Len", "(*bytes.Reader).Len", "(*bytes.Reader).Size": + cs = append(cs, NewIntIntervalConstraint(NewIntInterval(NewZ(0), PInfinity), ins)) + } + } + } + builtin, ok := ins.Common().Value.(*ssa.Builtin) + ops := ins.Operands(nil) + if !ok { + continue + } + switch builtin.Name() { + case "len": + switch op1 := (*ops[1]).Type().Underlying().(type) { + case *types.Basic: + if op1.Kind() == types.String || op1.Kind() == types.UntypedString { + cs = append(cs, NewStringLengthConstraint(*ops[1], ins)) + } + case *types.Slice: + cs = append(cs, NewSliceLengthConstraint(*ops[1], ins)) + } + + case "append": + cs = append(cs, NewSliceAppendConstraint(ins.Common().Args[0], ins.Common().Args[1], ins)) + } + case *ssa.BinOp: + ops := ins.Operands(nil) + basic, ok := (*ops[0]).Type().Underlying().(*types.Basic) + if !ok { + continue + } + switch basic.Kind() { + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64, + types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64, types.UntypedInt: + fns := map[token.Token]func(ssa.Value, ssa.Value, ssa.Value) Constraint{ + token.ADD: NewIntAddConstraint, + token.SUB: NewIntSubConstraint, + token.MUL: NewIntMulConstraint, + // XXX support QUO, REM, SHL, SHR + } + fn, ok := fns[ins.Op] + if ok { + cs = append(cs, fn(*ops[0], *ops[1], ins)) + } + case types.String, types.UntypedString: + if ins.Op == token.ADD { + cs = append(cs, NewStringConcatConstraint(*ops[0], *ops[1], ins)) + } + } + case *ssa.Slice: + typ := ins.X.Type().Underlying() + switch typ := typ.(type) { + case *types.Basic: + cs = append(cs, NewStringSliceConstraint(ins.X, ins.Low, ins.High, ins)) + case *types.Slice: + cs = append(cs, NewSliceSliceConstraint(ins.X, ins.Low, ins.High, ins)) + case *types.Array: + cs = append(cs, NewArraySliceConstraint(ins.X, ins.Low, ins.High, ins)) + case *types.Pointer: + if _, ok := typ.Elem().(*types.Array); !ok { + continue + } + cs = append(cs, NewArraySliceConstraint(ins.X, ins.Low, ins.High, ins)) + } + case *ssa.Phi: + if !isSupportedType(ins.Type()) { + continue + } + ops := ins.Operands(nil) + dops := make([]ssa.Value, len(ops)) + for i, op := range ops { + dops[i] = *op + } + cs = append(cs, NewPhiConstraint(dops, ins)) + case *ssa.Sigma: + pred := ins.Block().Preds[0] + instrs := pred.Instrs + cond, ok := instrs[len(instrs)-1].(*ssa.If).Cond.(*ssa.BinOp) + ops := cond.Operands(nil) + if !ok { + continue + } + switch typ := ins.Type().Underlying().(type) { + case *types.Basic: + var c Constraint + switch typ.Kind() { + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64, + types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64, types.UntypedInt: + c = sigmaInteger(g, ins, cond, ops) + case types.String, types.UntypedString: + c = sigmaString(g, ins, cond, ops) + } + if c != nil { + cs = append(cs, c) + } + case *types.Slice: + c := sigmaSlice(g, ins, cond, ops) + if c != nil { + cs = append(cs, c) + } + default: + //log.Printf("unsupported sigma type %T", typ) // XXX + } + case *ssa.MakeChan: + cs = append(cs, NewMakeChannelConstraint(ins.Size, ins)) + case *ssa.MakeSlice: + cs = append(cs, NewMakeSliceConstraint(ins.Len, ins)) + case *ssa.ChangeType: + switch ins.X.Type().Underlying().(type) { + case *types.Chan: + cs = append(cs, NewChannelChangeTypeConstraint(ins.X, ins)) + } + } + } + } + + for _, c := range cs { + if c == nil { + panic("nil constraint") + } + // If V is used in constraint C, then we create an edge V->C + for _, op := range c.Operands() { + g.AddEdge(op, c, false) + } + if c, ok := c.(Future); ok { + for _, op := range c.Futures() { + g.AddEdge(op, c, true) + } + } + // If constraint C defines variable V, then we create an edge + // C->V + g.AddEdge(c, c.Y(), false) + } + + g.FindSCCs() + g.sccEdges = make([][]Edge, len(g.SCCs)) + g.futures = make([][]Future, len(g.SCCs)) + for _, e := range g.Edges { + g.sccEdges[e.From.SCC] = append(g.sccEdges[e.From.SCC], e) + if !e.control { + continue + } + if c, ok := e.To.Value.(Future); ok { + g.futures[e.From.SCC] = append(g.futures[e.From.SCC], c) + } + } + return g +} + +func (g *Graph) Solve() Ranges { + var consts []Z + off := NewZ(1) + for _, n := range g.Vertices { + if c, ok := n.Value.(*ssa.Const); ok { + basic, ok := c.Type().Underlying().(*types.Basic) + if !ok { + continue + } + if (basic.Info() & types.IsInteger) != 0 { + z := ConstantToZ(c.Value) + consts = append(consts, z) + consts = append(consts, z.Add(off)) + consts = append(consts, z.Sub(off)) + } + } + + } + sort.Sort(Zs(consts)) + + for scc, vertices := range g.SCCs { + n := 0 + n = len(vertices) + if n == 1 { + g.resolveFutures(scc) + v := vertices[0] + if v, ok := v.Value.(ssa.Value); ok { + switch typ := v.Type().Underlying().(type) { + case *types.Basic: + switch typ.Kind() { + case types.String, types.UntypedString: + if !g.Range(v).(StringInterval).IsKnown() { + g.SetRange(v, StringInterval{NewIntInterval(NewZ(0), PInfinity)}) + } + default: + if !g.Range(v).(IntInterval).IsKnown() { + g.SetRange(v, InfinityFor(v)) + } + } + case *types.Chan: + if !g.Range(v).(ChannelInterval).IsKnown() { + g.SetRange(v, ChannelInterval{NewIntInterval(NewZ(0), PInfinity)}) + } + case *types.Slice: + if !g.Range(v).(SliceInterval).IsKnown() { + g.SetRange(v, SliceInterval{NewIntInterval(NewZ(0), PInfinity)}) + } + } + } + if c, ok := v.Value.(Constraint); ok { + g.SetRange(c.Y(), c.Eval(g)) + } + } else { + uses := g.uses(scc) + entries := g.entries(scc) + for len(entries) > 0 { + v := entries[len(entries)-1] + entries = entries[:len(entries)-1] + for _, use := range uses[v] { + if g.widen(use, consts) { + entries = append(entries, use.Y()) + } + } + } + + g.resolveFutures(scc) + + // XXX this seems to be necessary, but shouldn't be. + // removing it leads to nil pointer derefs; investigate + // where we're not setting values correctly. + for _, n := range vertices { + if v, ok := n.Value.(ssa.Value); ok { + i, ok := g.Range(v).(IntInterval) + if !ok { + continue + } + if !i.IsKnown() { + g.SetRange(v, InfinityFor(v)) + } + } + } + + actives := g.actives(scc) + for len(actives) > 0 { + v := actives[len(actives)-1] + actives = actives[:len(actives)-1] + for _, use := range uses[v] { + if g.narrow(use) { + actives = append(actives, use.Y()) + } + } + } + } + // propagate scc + for _, edge := range g.sccEdges[scc] { + if edge.control { + continue + } + if edge.From.SCC == edge.To.SCC { + continue + } + if c, ok := edge.To.Value.(Constraint); ok { + g.SetRange(c.Y(), c.Eval(g)) + } + if c, ok := edge.To.Value.(Future); ok { + if !c.IsKnown() { + c.MarkUnresolved() + } + } + } + } + + for v, r := range g.ranges { + i, ok := r.(IntInterval) + if !ok { + continue + } + if (v.Type().Underlying().(*types.Basic).Info() & types.IsUnsigned) == 0 { + if i.Upper != PInfinity { + s := &types.StdSizes{ + // XXX is it okay to assume the largest word size, or do we + // need to be platform specific? + WordSize: 8, + MaxAlign: 1, + } + bits := (s.Sizeof(v.Type()) * 8) - 1 + n := big.NewInt(1) + n = n.Lsh(n, uint(bits)) + upper, lower := &big.Int{}, &big.Int{} + upper.Sub(n, big.NewInt(1)) + lower.Neg(n) + + if i.Upper.Cmp(NewBigZ(upper)) == 1 { + i = NewIntInterval(NInfinity, PInfinity) + } else if i.Lower.Cmp(NewBigZ(lower)) == -1 { + i = NewIntInterval(NInfinity, PInfinity) + } + } + } + + g.ranges[v] = i + } + + return g.ranges +} + +func VertexString(v *Vertex) string { + switch v := v.Value.(type) { + case Constraint: + return v.String() + case ssa.Value: + return v.Name() + case nil: + return "BUG: nil vertex value" + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } +} + +type Vertex struct { + Value interface{} // one of Constraint or ssa.Value + SCC int + index int + lowlink int + stack bool + + Succs []Edge +} + +type Ranges map[ssa.Value]Range + +func (r Ranges) Get(x ssa.Value) Range { + if x == nil { + return nil + } + i, ok := r[x] + if !ok { + switch x := x.Type().Underlying().(type) { + case *types.Basic: + switch x.Kind() { + case types.String, types.UntypedString: + return StringInterval{} + default: + return IntInterval{} + } + case *types.Chan: + return ChannelInterval{} + case *types.Slice: + return SliceInterval{} + } + } + return i +} + +type Graph struct { + Vertices map[interface{}]*Vertex + Edges []Edge + SCCs [][]*Vertex + ranges Ranges + + // map SCCs to futures + futures [][]Future + // map SCCs to edges + sccEdges [][]Edge +} + +func (g Graph) Graphviz() string { + var lines []string + lines = append(lines, "digraph{") + ids := map[interface{}]int{} + i := 1 + for _, v := range g.Vertices { + ids[v] = i + shape := "box" + if _, ok := v.Value.(ssa.Value); ok { + shape = "oval" + } + lines = append(lines, fmt.Sprintf(`n%d [shape="%s", label=%q, colorscheme=spectral11, style="filled", fillcolor="%d"]`, + i, shape, VertexString(v), (v.SCC%11)+1)) + i++ + } + for _, e := range g.Edges { + style := "solid" + if e.control { + style = "dashed" + } + lines = append(lines, fmt.Sprintf(`n%d -> n%d [style="%s"]`, ids[e.From], ids[e.To], style)) + } + lines = append(lines, "}") + return strings.Join(lines, "\n") +} + +func (g *Graph) SetRange(x ssa.Value, r Range) { + g.ranges[x] = r +} + +func (g *Graph) Range(x ssa.Value) Range { + return g.ranges.Get(x) +} + +func (g *Graph) widen(c Constraint, consts []Z) bool { + setRange := func(i Range) { + g.SetRange(c.Y(), i) + } + widenIntInterval := func(oi, ni IntInterval) (IntInterval, bool) { + if !ni.IsKnown() { + return oi, false + } + nlc := NInfinity + nuc := PInfinity + for _, co := range consts { + if co.Cmp(ni.Lower) <= 0 { + nlc = co + break + } + } + for _, co := range consts { + if co.Cmp(ni.Upper) >= 0 { + nuc = co + break + } + } + + if !oi.IsKnown() { + return ni, true + } + if ni.Lower.Cmp(oi.Lower) == -1 && ni.Upper.Cmp(oi.Upper) == 1 { + return NewIntInterval(nlc, nuc), true + } + if ni.Lower.Cmp(oi.Lower) == -1 { + return NewIntInterval(nlc, oi.Upper), true + } + if ni.Upper.Cmp(oi.Upper) == 1 { + return NewIntInterval(oi.Lower, nuc), true + } + return oi, false + } + switch oi := g.Range(c.Y()).(type) { + case IntInterval: + ni := c.Eval(g).(IntInterval) + si, changed := widenIntInterval(oi, ni) + if changed { + setRange(si) + return true + } + return false + case StringInterval: + ni := c.Eval(g).(StringInterval) + si, changed := widenIntInterval(oi.Length, ni.Length) + if changed { + setRange(StringInterval{si}) + return true + } + return false + case SliceInterval: + ni := c.Eval(g).(SliceInterval) + si, changed := widenIntInterval(oi.Length, ni.Length) + if changed { + setRange(SliceInterval{si}) + return true + } + return false + default: + return false + } +} + +func (g *Graph) narrow(c Constraint) bool { + narrowIntInterval := func(oi, ni IntInterval) (IntInterval, bool) { + oLower := oi.Lower + oUpper := oi.Upper + nLower := ni.Lower + nUpper := ni.Upper + + if oLower == NInfinity && nLower != NInfinity { + return NewIntInterval(nLower, oUpper), true + } + if oUpper == PInfinity && nUpper != PInfinity { + return NewIntInterval(oLower, nUpper), true + } + if oLower.Cmp(nLower) == 1 { + return NewIntInterval(nLower, oUpper), true + } + if oUpper.Cmp(nUpper) == -1 { + return NewIntInterval(oLower, nUpper), true + } + return oi, false + } + switch oi := g.Range(c.Y()).(type) { + case IntInterval: + ni := c.Eval(g).(IntInterval) + si, changed := narrowIntInterval(oi, ni) + if changed { + g.SetRange(c.Y(), si) + return true + } + return false + case StringInterval: + ni := c.Eval(g).(StringInterval) + si, changed := narrowIntInterval(oi.Length, ni.Length) + if changed { + g.SetRange(c.Y(), StringInterval{si}) + return true + } + return false + case SliceInterval: + ni := c.Eval(g).(SliceInterval) + si, changed := narrowIntInterval(oi.Length, ni.Length) + if changed { + g.SetRange(c.Y(), SliceInterval{si}) + return true + } + return false + default: + return false + } +} + +func (g *Graph) resolveFutures(scc int) { + for _, c := range g.futures[scc] { + c.Resolve() + } +} + +func (g *Graph) entries(scc int) []ssa.Value { + var entries []ssa.Value + for _, n := range g.Vertices { + if n.SCC != scc { + continue + } + if v, ok := n.Value.(ssa.Value); ok { + // XXX avoid quadratic runtime + // + // XXX I cannot think of any code where the future and its + // variables aren't in the same SCC, in which case this + // code isn't very useful (the variables won't be resolved + // yet). Before we have a cross-SCC example, however, we + // can't really verify that this code is working + // correctly, or indeed doing anything useful. + for _, on := range g.Vertices { + if c, ok := on.Value.(Future); ok { + if c.Y() == v { + if !c.IsResolved() { + g.SetRange(c.Y(), c.Eval(g)) + c.MarkResolved() + } + break + } + } + } + if g.Range(v).IsKnown() { + entries = append(entries, v) + } + } + } + return entries +} + +func (g *Graph) uses(scc int) map[ssa.Value][]Constraint { + m := map[ssa.Value][]Constraint{} + for _, e := range g.sccEdges[scc] { + if e.control { + continue + } + if v, ok := e.From.Value.(ssa.Value); ok { + c := e.To.Value.(Constraint) + sink := c.Y() + if g.Vertices[sink].SCC == scc { + m[v] = append(m[v], c) + } + } + } + return m +} + +func (g *Graph) actives(scc int) []ssa.Value { + var actives []ssa.Value + for _, n := range g.Vertices { + if n.SCC != scc { + continue + } + if v, ok := n.Value.(ssa.Value); ok { + if _, ok := v.(*ssa.Const); !ok { + actives = append(actives, v) + } + } + } + return actives +} + +func (g *Graph) AddEdge(from, to interface{}, ctrl bool) { + vf, ok := g.Vertices[from] + if !ok { + vf = &Vertex{Value: from} + g.Vertices[from] = vf + } + vt, ok := g.Vertices[to] + if !ok { + vt = &Vertex{Value: to} + g.Vertices[to] = vt + } + e := Edge{From: vf, To: vt, control: ctrl} + g.Edges = append(g.Edges, e) + vf.Succs = append(vf.Succs, e) +} + +type Edge struct { + From, To *Vertex + control bool +} + +func (e Edge) String() string { + return fmt.Sprintf("%s -> %s", VertexString(e.From), VertexString(e.To)) +} + +func (g *Graph) FindSCCs() { + // use Tarjan to find the SCCs + + index := 1 + var s []*Vertex + + scc := 0 + var strongconnect func(v *Vertex) + strongconnect = func(v *Vertex) { + // set the depth index for v to the smallest unused index + v.index = index + v.lowlink = index + index++ + s = append(s, v) + v.stack = true + + for _, e := range v.Succs { + w := e.To + if w.index == 0 { + // successor w has not yet been visited; recurse on it + strongconnect(w) + if w.lowlink < v.lowlink { + v.lowlink = w.lowlink + } + } else if w.stack { + // successor w is in stack s and hence in the current scc + if w.index < v.lowlink { + v.lowlink = w.index + } + } + } + + if v.lowlink == v.index { + for { + w := s[len(s)-1] + s = s[:len(s)-1] + w.stack = false + w.SCC = scc + if w == v { + break + } + } + scc++ + } + } + for _, v := range g.Vertices { + if v.index == 0 { + strongconnect(v) + } + } + + g.SCCs = make([][]*Vertex, scc) + for _, n := range g.Vertices { + n.SCC = scc - n.SCC - 1 + g.SCCs[n.SCC] = append(g.SCCs[n.SCC], n) + } +} + +func invertToken(tok token.Token) token.Token { + switch tok { + case token.LSS: + return token.GEQ + case token.GTR: + return token.LEQ + case token.EQL: + return token.NEQ + case token.NEQ: + return token.EQL + case token.GEQ: + return token.LSS + case token.LEQ: + return token.GTR + default: + panic(fmt.Sprintf("unsupported token %s", tok)) + } +} + +func flipToken(tok token.Token) token.Token { + switch tok { + case token.LSS: + return token.GTR + case token.GTR: + return token.LSS + case token.EQL: + return token.EQL + case token.NEQ: + return token.NEQ + case token.GEQ: + return token.LEQ + case token.LEQ: + return token.GEQ + default: + panic(fmt.Sprintf("unsupported token %s", tok)) + } +} + +type CopyConstraint struct { + aConstraint + X ssa.Value +} + +func (c *CopyConstraint) String() string { + return fmt.Sprintf("%s = copy(%s)", c.Y().Name(), c.X.Name()) +} + +func (c *CopyConstraint) Eval(g *Graph) Range { + return g.Range(c.X) +} + +func (c *CopyConstraint) Operands() []ssa.Value { + return []ssa.Value{c.X} +} + +func NewCopyConstraint(x, y ssa.Value) Constraint { + return &CopyConstraint{ + aConstraint: aConstraint{ + y: y, + }, + X: x, + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index ecd88c2b9d6..9ee3496797f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -196,13 +196,20 @@ gopkg.in/testfixtures.v2 gopkg.in/yaml.v2 # honnef.co/go/tools v0.0.0-20180920025451-e3ad64cb4ed3 honnef.co/go/tools/cmd/gosimple +honnef.co/go/tools/cmd/staticcheck honnef.co/go/tools/cmd/unused honnef.co/go/tools/lint/lintutil honnef.co/go/tools/simple +honnef.co/go/tools/staticcheck honnef.co/go/tools/unused honnef.co/go/tools/lint honnef.co/go/tools/version honnef.co/go/tools/internal/sharedcheck honnef.co/go/tools/lint/lintdsl +honnef.co/go/tools/deprecated +honnef.co/go/tools/functions honnef.co/go/tools/ssa +honnef.co/go/tools/staticcheck/vrp honnef.co/go/tools/ssa/ssautil +honnef.co/go/tools/callgraph +honnef.co/go/tools/callgraph/static