Mode eingebaut

This commit is contained in:
konrad 2017-08-31 18:37:43 +02:00
parent fc139036c0
commit f2d411d666
485 changed files with 141934 additions and 95401 deletions

View File

@ -11,4 +11,10 @@ TODO::
* ~~Kofis löschen~~
* ~~Logout~~
* Konfis auf der Frontseite mit Websockets updaten
* Mode hinzufügen: in der Config soll man zwischen Gemeinden und Konfis umschalten können, so dass entweder die Gemeinden oder Konfis gegeneinander spielen
* ~~Mode hinzufügen: in der Config soll man zwischen Gemeinden und Konfis umschalten können, so dass entweder die Gemeinden oder Konfis gegeneinander spielen~~
* Alles in Fenster packen
* Random Port (wenn man irgendwie den laufenden port finden kann)
* Front-Tabelle schöner
* Mini-Doku
* Inbetriebnahme
* Bedienung

40
add.go
View File

@ -6,6 +6,10 @@ import (
)
func addKonfi(c echo.Context) error {
//Config
SiteConf := initConfig()
//Datenbankverbindung aufbauen
db := DBinit()
@ -18,16 +22,36 @@ func addKonfi(c echo.Context) error {
//Wenn eingeloggt
if logged != nil {
kofi := new(Kofi)
kofi.Name = c.FormValue("name")
kofi.Gemeinde = c.FormValue("gemeinde")
// Mode nach Kofis
if SiteConf.Mode == 0 {
kofi := new(Kofi)
kofi.Name = c.FormValue("name")
kofi.Gemeinde = c.FormValue("gemeinde")
// Einfügen
_, err := db.Insert(kofi)
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
return c.JSON(http.StatusInternalServerError, Message{"Error."})
} else if SiteConf.Mode == 1 { // Mode nach Gemeinden
gemeinde := new(Gemeinde)
gemeinde.Name = c.FormValue("name")
_, err := db.Insert(gemeinde)
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
return c.JSON(http.StatusInternalServerError, Message{"Error."})
//Aktuelle Coins holen
_, err := db.Insert(kofi)
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
return c.JSON(http.StatusOK, Message{"Error."})
return c.JSON(http.StatusInternalServerError, Message{"Wrong Mode."})
} else {
return c.JSON(http.StatusOK, Message{"Login first."})
}

View File

@ -3,13 +3,18 @@ package main
import (
"github.com/labstack/echo"
"net/http"
"strconv"
)
type Loggedin struct {
type AdminInfos struct {
Loggedin bool
Mode int
}
func adminHandler(c echo.Context) error {
//Config
SiteConf := initConfig()
rw := c.Response()
r := c.Request()
@ -20,8 +25,8 @@ func adminHandler(c echo.Context) error {
loggedin := sess.Get("login")
if loggedin != nil {
return c.Render(http.StatusOK, "admin", Loggedin{true})
return c.Render(http.StatusOK, "admin_mode_" + strconv.Itoa(SiteConf.Mode), AdminInfos{true, SiteConf.Mode})
} else {
return c.Render(http.StatusOK, "admin", Loggedin{false})
return c.Render(http.StatusOK, "login", AdminInfos{false, SiteConf.Mode})
}
}

View File

@ -1,9 +1,13 @@
function getList() {
$.getJSON('/list?asc=asc', function (data) {
//console.log(data);
$("#konfis").html('');
$("#list").html('');
$.each(data, function (i, item) {
$("#konfis").append('<tr id="kcoins_row_' + item.ID + '"> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td id="kcoins_display_' + item.ID + '">' + item.KCoins + '</td><td><span class="ui action input" id="kcoins_container_' + item.ID + '"><input type="number" value="0" id="kcoins_' + item.ID + '" name="kcoins" autocomplete="off" /><button class="ui right labeled icon button green" onclick="updateCoins(\'' + item.ID + '\');"><i class="right dollar icon"></i>KonfiCoins Hinzufügen</button></span>&nbsp;&nbsp;&nbsp;&nbsp;<button class="ui button red" onclick="deleteKonfi(\'' + item.ID + '\');" id="kcoins_container_' + item.ID + '">Konfi Löschen</button></td></tr>');
if(item.Gemeinde !== undefined) {
$("#list").append('<tr id="kcoins_row_' + item.ID + '"> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td id="kcoins_display_' + item.ID + '">' + item.KCoins + '</td><td><span class="ui action input" id="kcoins_container_' + item.ID + '"><input type="number" value="0" id="kcoins_' + item.ID + '" name="kcoins" autocomplete="off" /><button class="ui right labeled icon button green" onclick="updateCoins(\'' + item.ID + '\');"><i class="right dollar icon"></i>KonfiCoins Hinzufügen</button></span>&nbsp;&nbsp;&nbsp;&nbsp;<button class="ui button red" onclick="deleteKonfi(\'' + item.ID + '\');" id="kcoins_container_' + item.ID + '">Konfi Löschen</button></td></tr>');
} else {
$("#list").append('<tr id="kcoins_row_' + item.ID + '"> <td>' + item.Name + '</td> <td id="kcoins_display_' + item.ID + '">' + item.KCoins + '</td><td><span class="ui action input" id="kcoins_container_' + item.ID + '"><input type="number" value="0" id="kcoins_' + item.ID + '" name="kcoins" autocomplete="off" /><button class="ui right labeled icon button green" onclick="updateCoins(\'' + item.ID + '\');"><i class="right dollar icon"></i>KonfiCoins Hinzufügen</button></span>&nbsp;&nbsp;&nbsp;&nbsp;<button class="ui button red" onclick="deleteGemeinde(\'' + item.ID + '\');" id="kcoins_container_' + item.ID + '">Gemeinde Löschen</button></td></tr>');
}
});
});
}
@ -15,7 +19,7 @@ function updateCoins(id) {
var addcoins = $('#kcoins_' + id).val();
if(addcoins != 0) {
$('#kcoins_container_' + id).addClass('disabled');
$('#coins_container_' + id).addClass('disabled');
$.ajax({
url: '/update',
@ -23,11 +27,11 @@ function updateCoins(id) {
data: 'id=' + id + '&addcoins=' + addcoins,
success: function (msg) {
console.log(msg);
$('#kcoins_container_' + id).removeClass('disabled');
$('#coins_container_' + id).removeClass('disabled');
if (msg.Message == 'success') {
$('#kcoins_' + id).val("0");
$('#kcoins_display_' + id).html(msg.Kofi.KCoins);
$('#kcoins_display_' + id).html(msg.Data.KCoins);
} else {
$('#msg').html('<div class="ui error message" style="display: block;">Ein Fehler trat auf.</div>');
@ -72,6 +76,42 @@ function deleteKonfi(id) {
;
}
function deleteGemeinde(id) {
console.log('Delete', id);
$('#kcoins_container_' + id).addClass('disabled');
$('.ui.basic.gemeindedel.modal')
.modal({
closable : false,
duration: 200,
onDeny : function(){
$('#kcoins_container_' + id).removeClass('disabled');
return true;
},
onApprove : function() {
$.ajax({
url: '/delete',
method: 'POST',
data: 'id=' + id,
success: function (msg) {
console.log(msg);
if (msg.Message == 'success') {
//$('#kcoins_row_' + id).remove();
getList();
$('#msg').html('<div class="ui success message" style="display: block;">Die Gemeinde wurde erfolgreich gelöscht.</div>');
} else {
$('#msg').html('<div class="ui error message" style="display: block;">Ein Fehler trat auf.</div>');
}
}
});
}
})
.modal('show')
;
}
// Konfi hinzufügen
$('.ui.kofiadd.modal')
.modal({
duration: 200,
@ -101,6 +141,35 @@ $('.ui.kofiadd.modal')
.modal('attach events', '.addKofi.button', 'show')
;
$('.ui.gemeindeadd.modal')
.modal({
duration: 200,
onApprove : function() {
$('.loader').addClass('active');
console.log('bul');
$.ajax({
url: '/add',
method: 'POST',
data: 'name=' + $('#name').val(),
success: function (msg) {
$('.loader').removeClass('active');
console.log(msg);
if (msg.Message == 'success') {
$('#name').val('');
getList();
$('#msg').html('<div class="ui success message" style="display: block;">Die Gemeinde wurde erfolgreich hinzugefügt.</div>');
} else {
$('#msg').html('<div class="ui error message" style="display: block;">Ein Fehler trat auf.</div>');
}
}
});
}
})
.modal('attach events', '.addGemeinde.button', 'show')
;
$('.ui.kofiupload.modal')
.modal('attach events', '.ui.right.labeled.icon.uploadKofis.button.blue', 'show')
;

View File

@ -3,7 +3,11 @@ setInterval(function() {
//console.log(data);
$( "#konfis" ).html('');
$.each( data, function( i, item ) {
$( "#konfis" ).append('<tr> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td>' + item.KCoins + '</td></tr>');
if (item.Gemeinde != undefined) {
$( "#konfis" ).append('<tr> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td>' + item.KCoins + '</td></tr>');
} else {
$( "#konfis" ).append('<tr> <td>' + item.Name + '</td> <td>' + item.KCoins + '</td></tr>');
}
});
});
}, 1000);

View File

@ -10,6 +10,7 @@ type Configuration struct {
AdminPassword string
Interface string
DBFile string
Mode int
}
var SiteConf Configuration = Configuration{}

View File

@ -1,3 +1,12 @@
; Das Adminpasswort, wird benötigt, um sich unter /admin einzuloggen
AdminPassword = geheim
; 0 = Konfis sind selbstständig
; 1 = Gemeinden spielen gegeneinancer
Mode = 1
; Serverkram
; Das Interface inkl. Port, auf dem der Webserver läuft
Interface = :8080
; Hier wird die Datenbank gespeichert
DBFile = ./data.db

View File

@ -7,6 +7,10 @@ import (
)
func deleteKonfi(c echo.Context) error {
//Config
SiteConf := initConfig()
//Datenbankverbindung aufbauen
db := DBinit()
@ -22,12 +26,19 @@ func deleteKonfi(c echo.Context) error {
id, _ := strconv.Atoi(c.FormValue("id"))
//Löschen
_, err := db.Id(id).Delete(&Kofi{})
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
if SiteConf.Mode == 0 {
_, err := db.Id(id).Delete(&Kofi{})
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
} else if SiteConf.Mode == 1{
_, err := db.Id(id).Delete(&Gemeinde{})
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
}
return c.JSON(http.StatusOK, Message{"Error."})
return c.JSON(http.StatusInternalServerError, Message{"Error."})
} else {
return c.JSON(http.StatusOK, Message{"Login first."})
return c.JSON(http.StatusForbidden, Message{"Login first."})
}
}

View File

@ -10,21 +10,46 @@ func getList(c echo.Context) error {
//Datenbankverbindung aufbauen
db := DBinit()
//Daten holen und anzeigen
var kofi []Kofi
asc := c.QueryParam("asc")
if asc == "" {
err := db.OrderBy("KCoins DESC").Find(&kofi)
if err != nil {
fmt.Println(err)
//Config
SiteConf := initConfig()
if SiteConf.Mode == 0 {
//Daten holen und anzeigen
var kofi []Kofi
asc := c.QueryParam("asc")
if asc == "" {
err := db.OrderBy("KCoins DESC").Find(&kofi)
if err != nil {
fmt.Println(err)
}
} else {
err := db.OrderBy("Name ASC").Find(&kofi)
if err != nil {
fmt.Println(err)
}
}
} else {
err := db.OrderBy("Name ASC").Find(&kofi)
if err != nil {
fmt.Println(err)
//Template
return c.JSON(http.StatusOK, kofi)
} else if SiteConf.Mode == 1 {
//Daten holen und anzeigen
var gemeinden []Gemeinde
asc := c.QueryParam("asc")
if asc == "" {
err := db.OrderBy("KCoins DESC").Find(&gemeinden)
if err != nil {
fmt.Println(err)
}
} else {
err := db.OrderBy("Name ASC").Find(&gemeinden)
if err != nil {
fmt.Println(err)
}
}
//Template
return c.JSON(http.StatusOK, gemeinden)
}
//Template
return c.JSON(http.StatusOK, kofi)
return c.HTML(http.StatusInternalServerError, "Error. (Wrong mode)")
}

View File

@ -79,7 +79,8 @@ func main() {
//DB init - Create tables
db := DBinit()
db.CreateTables(&Kofi{})
db.Sync(&Kofi{})
db.Sync(&Gemeinde{})
//Start the server
e.Logger.SetLevel(log.ERROR)

View File

@ -6,6 +6,10 @@ import (
)
func showList(c echo.Context) error {
//Config
SiteConf := initConfig()
//Template
return c.Render(http.StatusOK, "index", Message{"schinken"})
return c.Render(http.StatusOK, "index", SiteConf)
}

View File

@ -1,4 +1,4 @@
{{define "admin"}}
{{define "admin_mode_0"}}
<!DOCTYPE html>
<html lang="de">
<head>
@ -37,7 +37,7 @@
<th>Bearbeiten</th>
</tr>
</thead>
<tbody id="konfis">
<tbody id="list">
<tr>
<td colspan="3">Laden...</td>
</tr>
@ -75,14 +75,12 @@
Konfi hinzufügen
</div>
<div class="image content">
<form action="#" method="post">
<div class="ui input">
<input type="text" id="name" placeholder="Name"/>
</div><br/><br/>
<div class="ui input">
<input type="text" id="gemeinde" placeholder="Gemeinde"/>
</div>
</form>
</div>
<div class="actions">
<div class="ui black deny button">
@ -108,41 +106,14 @@
<div class="ui black deny button">
Abbrechen
</div>
<div class="ui positive button">
<div class="ui positive button">include
Hochladen
</div>
</div>
</div>
{{else}}
<body style="background: url(/assets/bg.jpg) no-repeat center fixed">
<div class="ui middle aligned center aligned grid" style="width: 30em; margin: 37vh auto;">
<div class="column">
<h2 class="ui header">
Kasino Admin
</h2>
<form class="ui large form" id="loginform" method="post">
<div class="ui segment">
<div class="field">
<div class="ui left icon input">
<i class="lock icon"></i>
<input type="password" name="password" id="password" placeholder="Passwort" autofocus>
</div>
</div>
<div class="ui fluid large blue submit button">Login</div>
</div>
<div id="msg"></div>
</form>
</div>
</div>
{{end}}
{{if .Loggedin}}
<script src="/assets/js/admin.js"></script>
{{else}}
<script src="/assets/js/login.js"></script>
{{end}}
</body>
</html>
{{end}}
{{end}}include

95
tpl/admin_mode_1.html Normal file
View File

@ -0,0 +1,95 @@
{{define "admin_mode_1"}}
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Kasino Admin</title>
<link rel="stylesheet" type="text/css" href="/assets/semantic/semantic.min.css">
<script src="/assets/js/jquery-3.1.1.min.js"></script>
<script src="/assets/semantic/semantic.min.js"></script>
</head>
{{if .Loggedin}}
<body>
<div style="width: 98%; margin: 0 auto;">
<h1>Kasino Admin
<div class="ui inline loader"></div>
</h1>
<a href="/logout" class="ui right labeled icon button blue" style="float: right;">
<i class="right sign out icon"></i>
Ausloggen
</a>
<p>
<a class="ui right labeled icon addGemeinde button green">
<i class="right plus icon"></i>
Gemeinde Hinzufügen
</a>
</p>
<div id="msg"></div>
<table width="100%" border="0" cellpadding="0" cellspacing="0" class="ui celled table">
<thead>
<tr>
<th>Name</th>
<th>KonfiCoins</th>
<th>Bearbeiten</th>
</tr>
</thead>
<tbody id="list">
<tr>
<td colspan="3">Laden...</td>
</tr>
</tbody>
</table>
</div>
<!-- Modals -->
<!-- Gemeinde löschen -->
<div class="ui small basic gemeindedel modal">
<div class="ui icon header">
<i class="trash outline icon"></i>
Gemeinde löschen
</div>
<div class="content">
<p>Willst du diese Gemeinde wirklich löschen? Diese Aktion kann nicht Rückgängig gemacht werden!</p>
</div>
<div class="actions">
<div class="ui red basic cancel inverted button">
<i class="remove icon"></i>
Nein
</div>
<div class="ui green ok inverted button">
<i class="checkmark icon"></i>
Ja!
</div>
</div>
</div>
<!-- Gemeinde hinzufügen -->
<div class="ui gemeindeadd modal">
<i class="close icon"></i>
<div class="header">
Gemeinde hinzufügen
</div>
<div class="image content">
<div class="ui input">
<input type="text" id="name" placeholder="Name"/>
</div>
</div>
<div class="actions">
<div class="ui black deny button">
Abbrechen
</div>
<div class="ui positive right labeled icon button">
Hinzufügen
<i class="checkmark icon"></i>
</div>
</div>
</div>
<script src="/assets/js/admin.js"></script>
{{end}}
</body>
</html>
{{end}}

View File

@ -10,7 +10,9 @@
<table width="100%" border="0" cellpadding="0" cellspacing="0">
<tr class="top">
<th scope="col">Name</th>
{{if eq .Mode 0}}
<th scope="col">Gemeinde</th>
{{end}}
<th scope="col">Eingezahlte KonfiCoins</th>
</tr>
<tbody id="konfis">

37
tpl/login.html Normal file
View File

@ -0,0 +1,37 @@
{{define "login"}}
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Kasino Admin</title>
<link rel="stylesheet" type="text/css" href="/assets/semantic/semantic.min.css">
<script src="/assets/js/jquery-3.1.1.min.js"></script>
<script src="/assets/semantic/semantic.min.js"></script>
</head>
<body style="background: url(/assets/bg.jpg) no-repeat center fixed">
<div class="ui middle aligned center aligned grid" style="width: 30em; margin: 37vh auto;">
<div class="column">
<h2 class="ui header">
Kasino Admin
</h2>
<form class="ui large form" id="loginform" method="post">
<div class="ui segment">
<div class="field">
<div class="ui left icon input">
<i class="lock icon"></i>
<input type="password" name="password" id="password" placeholder="Passwort" autofocus>
</div>
</div>
<div class="ui fluid large blue submit button">Login</div>
</div>
<div id="msg"></div>
</form>
</div>
</div>
<script src="/assets/js/login.js"></script>
</body>
</html>
{{end}}

View File

@ -7,6 +7,10 @@ import (
)
func update(c echo.Context) error {
//Config
SiteConf := initConfig()
//Datenbankverbindung aufbauen
db := DBinit()
@ -22,23 +26,43 @@ func update(c echo.Context) error {
id, _ := strconv.Atoi(c.FormValue("id"))
addcoins, _ := strconv.Atoi(c.FormValue("addcoins"))
//Aktuelle Coins holen
var kofi = Kofi{ID: id}
has, err := db.Get(&kofi)
checkErr(err)
if has {
newCoins := kofi.KCoins + addcoins
if SiteConf.Mode == 0 {
//Aktuelle Coins holen
var kofi= Kofi{ID: id}
has, err := db.Get(&kofi)
checkErr(err)
if has {
newCoins := kofi.KCoins + addcoins
//Updaten
kofi.KCoins = newCoins
_, err := db.Id(id).Update(kofi)
//Updaten
kofi.KCoins = newCoins
_, err := db.Id(id).Update(kofi)
if err == nil {
return c.JSON(http.StatusOK, UpdatedMessage{"success", kofi})
if err == nil {
return c.JSON(http.StatusOK, UpdatedMessageKofi{"success", kofi})
}
}
} else if SiteConf.Mode == 1{
var gemeinde = Gemeinde{ID: id}
has, err := db.Get(&gemeinde)
checkErr(err)
if has {
newCoins := gemeinde.KCoins + addcoins
// Updaten
gemeinde.KCoins = newCoins
_, err := db.ID(id).Update(gemeinde)
if err == nil {
return c.JSON(http.StatusOK, UpdatedMessageGemeinde{"success", gemeinde})
}
}
}
return c.JSON(http.StatusOK, Message{"Error."})
return c.JSON(http.StatusInternalServerError, Message{"Error."})
} else {
return c.JSON(http.StatusOK, Message{"Login first."})
return c.JSON(http.StatusForbidden, Message{"Login first."})
}
}

View File

@ -9,13 +9,24 @@ type Kofi struct {
KCoins int
}
type Gemeinde struct {
ID int `xorm:"pk autoincr"`
Name string
KCoins int
}
type Message struct {
Message string
}
type UpdatedMessage struct {
type UpdatedMessageKofi struct {
Message string
Kofi Kofi
Data Kofi
}
type UpdatedMessageGemeinde struct {
Message string
Data Gemeinde
}
//CheckError

View File

@ -1,6 +1,8 @@
# A pure Go MSSQL driver for Go's database/sql package
[![GoDoc](https://godoc.org/github.com/denisenkom/go-mssqldb?status.svg)](http://godoc.org/github.com/denisenkom/go-mssqldb)
[![Build status](https://ci.appveyor.com/api/projects/status/ujv21jd241h8o5s7?svg=true)](https://ci.appveyor.com/project/denisenk/go-mssqldb)
[![codecov](https://codecov.io/gh/denisenkom/go-mssqldb/branch/master/graph/badge.svg)](https://codecov.io/gh/denisenkom/go-mssqldb)
## Install

View File

@ -44,7 +44,9 @@ before_test:
Start-Service "SQLBrowser"
- sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;"
- sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version"
- pip install codecov
test_script:
- go test -v -cover
- go test -race -coverprofile=coverage.txt -covermode=atomic
- codecov -f coverage.txt

View File

@ -1,11 +1,9 @@
package mssql
import (
"database/sql/driver"
"encoding/binary"
"errors"
"io"
"net"
)
type packetType uint8
@ -53,19 +51,6 @@ func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
return w
}
func checkBadConn(err error) error {
if err == io.EOF {
return driver.ErrBadConn
}
switch err.(type) {
case net.Error:
return driver.ErrBadConn
default:
return err
}
}
func (rw *tdsBuffer) ResizeBuffer(packetsizei int) {
if len(rw.rbuf) != packetsizei {
newbuf := make([]byte, packetsizei)
@ -152,7 +137,7 @@ func (r *tdsBuffer) readNextPacket() error {
var err error
err = binary.Read(r.transport, binary.BigEndian, &header)
if err != nil {
return checkBadConn(err)
return err
}
offset := uint16(binary.Size(header))
if int(header.Size) > len(r.rbuf) {
@ -163,7 +148,7 @@ func (r *tdsBuffer) readNextPacket() error {
}
_, err = io.ReadFull(r.transport, r.rbuf[offset:header.Size])
if err != nil {
return checkBadConn(err)
return err
}
r.rpos = offset
r.rsize = header.Size
@ -206,7 +191,7 @@ func (r *tdsBuffer) byte() byte {
func (r *tdsBuffer) ReadFull(buf []byte) {
_, err := io.ReadFull(r, buf[:])
if err != nil {
badStreamPanic(checkBadConn(err))
badStreamPanic(err)
}
}

View File

@ -71,11 +71,18 @@ func TestBulkcopy(t *testing.T) {
conn := open(t)
defer conn.Close()
setupTable(conn, tableName)
err := setupTable(conn, tableName)
if (err != nil) {
t.Error("Setup table failed: ", err.Error())
return
}
log.Println("Preparing copyin statement")
stmt, err := conn.Prepare(CopyIn(tableName, MssqlBulkOptions{}, columns...))
for i := 0; i < 10; i++ {
log.Printf("Executing copy in statement %d time with %d values", i+1, len(values))
_, err = stmt.Exec(values...)
if err != nil {
t.Error("AddRow failed: ", err.Error())
@ -142,8 +149,8 @@ func compareValue(a interface{}, expected interface{}) bool {
return reflect.DeepEqual(expected, a)
}
}
func setupTable(conn *sql.DB, tableName string) {
func setupTable(conn *sql.DB, tableName string) (err error) {
tablesql := `CREATE TABLE ` + tableName + ` (
[id] [int] IDENTITY(1,1) NOT NULL,
[test_nvarchar] [nvarchar](50) NULL,
@ -186,9 +193,9 @@ func setupTable(conn *sql.DB, tableName string) {
[id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY];`
_, err := conn.Exec(tablesql)
_, err = conn.Exec(tablesql)
if err != nil {
log.Fatal("tablesql failed:", err)
}
return
}

View File

@ -58,6 +58,38 @@ type MssqlConn struct {
transactionCtx context.Context
processQueryText bool
connectionGood bool
}
func (c *MssqlConn) checkBadConn(err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
// error indicates that connection is not usable
// but we return actual error instead of ErrBadConn
// this will cause connection to stay in a pool
// but next request to this connection will return ErrBadConn
// it might be possible to revise this hack after
// https://github.com/golang/go/issues/20807
// is implemented
if err == nil {
return nil
}
if err == io.EOF {
return driver.ErrBadConn
}
switch err.(type) {
case net.Error:
c.connectionGood = false
return err
case StreamError:
c.connectionGood = false
return err
default:
return err
}
}
func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
@ -67,18 +99,21 @@ func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
switch token := tok.(type) {
case doneStruct:
if token.isError() {
return token.getError()
return c.checkBadConn(token.getError())
}
case error:
return token
return c.checkBadConn(token)
}
}
return nil
}
func (c *MssqlConn) Commit() error {
if !c.connectionGood {
return driver.ErrBadConn
}
if err := c.sendCommitRequest(); err != nil {
return err
return c.checkBadConn(err)
}
return c.simpleProcessResp(c.transactionCtx)
}
@ -98,8 +133,11 @@ func (c *MssqlConn) sendCommitRequest() error {
}
func (c *MssqlConn) Rollback() error {
if !c.connectionGood {
return driver.ErrBadConn
}
if err := c.sendRollbackRequest(); err != nil {
return err
return c.checkBadConn(err)
}
return c.simpleProcessResp(c.transactionCtx)
}
@ -122,12 +160,19 @@ func (c *MssqlConn) Begin() (driver.Tx, error) {
return c.begin(context.Background(), isolationUseCurrent)
}
func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) {
err := c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, err
func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
return c.processBeginResponse(ctx)
err = c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, c.checkBadConn(err)
}
tx, err = c.processBeginResponse(ctx)
if err != nil {
return nil, c.checkBadConn(err)
}
return
}
func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
@ -183,7 +228,7 @@ func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
}
}
conn := &MssqlConn{sess, context.Background(), d.processQueryText}
conn := &MssqlConn{sess: sess, transactionCtx: context.Background(), processQueryText: d.processQueryText, connectionGood: true}
conn.sess.log = d.log
return conn, nil
}
@ -206,6 +251,9 @@ type queryNotifSub struct {
}
func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
return c.prepareCopyIn(query)
}
@ -315,9 +363,12 @@ func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.queryContext(context.Background(), convertOldArgs(args))
}
func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
if err := s.sendQuery(args); err != nil {
return nil, err
func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
}
return s.processQueryResponse(ctx)
}
@ -343,13 +394,13 @@ loop:
break loop
case doneStruct:
if token.isError() {
return nil, token.getError()
return nil, s.c.checkBadConn(token.getError())
}
case error:
return nil, token
return nil, s.c.checkBadConn(token)
}
}
res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel}
res = &MssqlRows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
return
}
@ -357,11 +408,17 @@ func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), convertOldArgs(args))
}
func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if err := s.sendQuery(args); err != nil {
return nil, err
func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
return s.processExec(ctx)
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
}
if res, err = s.processExec(ctx); err != nil {
return nil, s.c.checkBadConn(err)
}
return
}
func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
@ -389,7 +446,7 @@ func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err err
}
type MssqlRows struct {
sess *tdsSession
stmt *MssqlStmt
cols []columnStruct
tokchan chan tokenStruct
@ -415,6 +472,9 @@ func (rc *MssqlRows) Columns() (res []string) {
}
func (rc *MssqlRows) Next(dest []driver.Value) error {
if !rc.stmt.c.connectionGood {
return driver.ErrBadConn
}
if rc.nextCols != nil {
return io.EOF
}
@ -430,10 +490,10 @@ func (rc *MssqlRows) Next(dest []driver.Value) error {
return nil
case doneStruct:
if tokdata.isError() {
return tokdata.getError()
return rc.stmt.c.checkBadConn(tokdata.getError())
}
case error:
return tokdata
return rc.stmt.c.checkBadConn(tokdata)
}
}
return io.EOF

View File

@ -343,7 +343,6 @@ func TestExec(t *testing.T) {
}
func TestShortTimeout(t *testing.T) {
t.Skip("TODO: fix this test")
if testing.Short() {
t.Skip("short")
}
@ -841,7 +840,6 @@ func TestConnectionClosing(t *testing.T) {
}
func TestBeginTranError(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String())
@ -878,10 +876,13 @@ func TestBeginTranError(t *testing.T) {
case driver.ErrBadConn:
t.Error("processBeginResponse should fail with error different from ErrBadConn but it did")
}
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
}
func TestCommitTranError(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String())
@ -919,6 +920,10 @@ func TestCommitTranError(t *testing.T) {
t.Error("simpleProcessResp should fail with error different from ErrBadConn but it did")
}
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
// reopen connection
conn, err = drv.open(makeConnStr(t).String())
defer conn.Close()
@ -936,7 +941,6 @@ func TestCommitTranError(t *testing.T) {
}
func TestRollbackTranError(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String())
@ -974,6 +978,10 @@ func TestRollbackTranError(t *testing.T) {
t.Error("simpleProcessResp should fail with error different from ErrBadConn but it did")
}
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
// reopen connection
conn, err = drv.open(makeConnStr(t).String())
defer conn.Close()
@ -1031,7 +1039,6 @@ func TestSendQueryErrors(t *testing.T) {
}
func TestProcessQueryErrors(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String())
@ -1056,6 +1063,10 @@ func TestProcessQueryErrors(t *testing.T) {
if err == driver.ErrBadConn {
t.Error("processQueryResponse expected to fail with error other than ErrBadConn but it failed with it")
}
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
}
func TestSendExecErrors(t *testing.T) {

View File

@ -198,36 +198,6 @@ func TestConnect(t *testing.T) {
defer conn.Close()
}
func TestBadConnect(t *testing.T) {
t.Skip("TODO: fix this test")
var badDSNs []string
if parsed, err := url.Parse(os.Getenv("SQLSERVER_DSN")); err == nil {
parsed.User = url.UserPassword("baduser", "badpwd")
badDSNs = append(badDSNs, parsed.String())
}
if len(os.Getenv("HOST")) > 0 && len(os.Getenv("INSTANCE")) > 0 {
badDSNs = append(badDSNs,
fmt.Sprintf(
"Server=%s\\%s;User ID=baduser;Password=badpwd",
os.Getenv("HOST"), os.Getenv("INSTANCE"),
),
)
}
SetLogger(testLogger{t})
for _, badDsn := range badDSNs {
conn, err := sql.Open("mssql", badDsn)
if err != nil {
t.Error("Open connection failed:", err.Error())
}
defer conn.Close()
err = conn.Ping()
if err == nil {
t.Error("Ping should fail for connection: ", badDsn)
}
}
}
func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) {
stmt, err := conn.Prepare("select 1 as a")
if err != nil {

View File

@ -923,6 +923,10 @@ func makeGoLangScanType(ti typeInfo) reflect.Type {
default:
panic("invalid size of MONEYN")
}
case typeDateTim4:
return reflect.TypeOf(time.Time{})
case typeDateTime:
return reflect.TypeOf(time.Time{})
case typeDateTimeN:
switch ti.Size {
case 4:

View File

@ -1,11 +1,10 @@
sudo: false
language: go
go:
- 1.4
- 1.5
- 1.6
- 1.7
- 1.8
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- master
script:

63
vendor/github.com/go-ini/ini/ini.go generated vendored
View File

@ -24,10 +24,8 @@ import (
"os"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
)
const (
@ -37,7 +35,7 @@ const (
// Maximum allowed depth when recursively substituing variable names.
_DEPTH_VALUES = 99
_VERSION = "1.28.1"
_VERSION = "1.28.2"
)
// Version returns current package version literal.
@ -398,10 +396,7 @@ func (f *File) Append(source interface{}, others ...interface{}) error {
return f.Reload()
}
// WriteToIndent writes content into io.Writer with given indention.
// If PrettyFormat has been set to be true,
// it will align "=" sign with spaces under each section.
func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
func (f *File) writeToBuffer(indent string) (*bytes.Buffer, error) {
equalSign := "="
if PrettyFormat {
equalSign = " = "
@ -415,14 +410,14 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
if sec.Comment[0] != '#' && sec.Comment[0] != ';' {
sec.Comment = "; " + sec.Comment
}
if _, err = buf.WriteString(sec.Comment + LineBreak); err != nil {
return 0, err
if _, err := buf.WriteString(sec.Comment + LineBreak); err != nil {
return nil, err
}
}
if i > 0 || DefaultHeader {
if _, err = buf.WriteString("[" + sname + "]" + LineBreak); err != nil {
return 0, err
if _, err := buf.WriteString("[" + sname + "]" + LineBreak); err != nil {
return nil, err
}
} else {
// Write nothing if default section is empty
@ -432,8 +427,8 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
}
if sec.isRawSection {
if _, err = buf.WriteString(sec.rawBody); err != nil {
return 0, err
if _, err := buf.WriteString(sec.rawBody); err != nil {
return nil, err
}
continue
}
@ -469,8 +464,8 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
if key.Comment[0] != '#' && key.Comment[0] != ';' {
key.Comment = "; " + key.Comment
}
if _, err = buf.WriteString(key.Comment + LineBreak); err != nil {
return 0, err
if _, err := buf.WriteString(key.Comment + LineBreak); err != nil {
return nil, err
}
}
@ -488,8 +483,8 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
}
for _, val := range key.ValueWithShadows() {
if _, err = buf.WriteString(kname); err != nil {
return 0, err
if _, err := buf.WriteString(kname); err != nil {
return nil, err
}
if key.isBooleanType {
@ -510,20 +505,31 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
} else if !f.options.IgnoreInlineComment && strings.ContainsAny(val, "#;") {
val = "`" + val + "`"
}
if _, err = buf.WriteString(equalSign + val + LineBreak); err != nil {
return 0, err
if _, err := buf.WriteString(equalSign + val + LineBreak); err != nil {
return nil, err
}
}
}
if PrettySection {
// Put a line between sections
if _, err = buf.WriteString(LineBreak); err != nil {
return 0, err
if _, err := buf.WriteString(LineBreak); err != nil {
return nil, err
}
}
}
return buf, nil
}
// WriteToIndent writes content into io.Writer with given indention.
// If PrettyFormat has been set to be true,
// it will align "=" sign with spaces under each section.
func (f *File) WriteToIndent(w io.Writer, indent string) (int64, error) {
buf, err := f.writeToBuffer(indent)
if err != nil {
return 0, err
}
return buf.WriteTo(w)
}
@ -536,23 +542,12 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
func (f *File) SaveToIndent(filename, indent string) error {
// Note: Because we are truncating with os.Create,
// so it's safer to save to a temporary file location and rename afte done.
tmpPath := filename + "." + strconv.Itoa(time.Now().Nanosecond()) + ".tmp"
defer os.Remove(tmpPath)
fw, err := os.Create(tmpPath)
buf, err := f.writeToBuffer(indent)
if err != nil {
return err
}
if _, err = f.WriteToIndent(fw, indent); err != nil {
fw.Close()
return err
}
fw.Close()
// Remove old file and rename the new one.
os.Remove(filename)
return os.Rename(tmpPath, filename)
return ioutil.WriteFile(filename, buf.Bytes(), 0666)
}
// SaveTo writes content to file system.

View File

@ -13,12 +13,13 @@ const (
ONLYFROMDB
)
// database column
// Column defines database column
type Column struct {
Name string
TableName string
FieldName string
SQLType SQLType
IsJSON bool
Length int
Length2 int
Nullable bool

View File

@ -247,6 +247,18 @@ type Row struct {
err error // deferred error for easy chaining
}
// ErrorRow return an error row
func ErrorRow(err error) *Row {
return &Row{
err: err,
}
}
// NewRow from rows
func NewRow(rows *Rows, err error) *Row {
return &Row{rows, err}
}
func (row *Row) Columns() ([]string, error) {
if row.err != nil {
return nil, row.err

View File

@ -6,10 +6,6 @@ Xorm is a simple and powerful ORM for Go.
[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm)
[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3)
# Notice
The last master version is not backwards compatible. You should use `engine.ShowSQL()` and `engine.Logger().SetLevel()` instead of `engine.ShowSQL = `, `engine.ShowInfo = ` and so on.
# Features
* Struct <-> Table Mapping Support
@ -38,7 +34,7 @@ Drivers for Go's sql package which currently support database/sql includes:
* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/tree/master/godrv)
* Postgres: [github.com/lib/pq](https://github.com/lib/pq)
@ -52,6 +48,14 @@ Drivers for Go's sql package which currently support database/sql includes:
# Changelog
* **v0.6.3**
* merge tests to main project
* add `Exist` function
* add `SumInt` function
* Mysql now support read and create column comment.
* fix time related bugs.
* fix some other bugs.
* **v0.6.2**
* refactor tag parse methods
* add Scan features to Get
@ -64,22 +68,6 @@ methods can use `builder.Cond` as parameter
* add Sum, SumInt, SumInt64 and NotIn methods
* some bugs fixed
* **v0.5.0**
* logging interface changed
* some bugs fixed
* **v0.4.5**
* many bugs fixed
* extends support unlimited deepth
* Delete Limit support
* **v0.4.4**
* ql database expriment support
* tidb database expriment support
* sql.NullString and etc. field support
* select ForUpdate support
* many bugs fixed
[More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16)
# Installation
@ -126,7 +114,7 @@ results, err := engine.Query("select * from user")
results, err := engine.QueryString("select * from user")
```
* `Execute` runs a SQL string, it returns `affetcted` and `error`
* `Execute` runs a SQL string, it returns `affected` and `error`
```Go
affected, err := engine.Exec("update user set age = ? where name = ?", age, name)
@ -168,6 +156,25 @@ has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ?
```
* Check if one record exist on table
```Go
has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{
Name: "test1",
})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
```
* Query multiple records from database, also you can use join and extends
```Go
@ -260,6 +267,14 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
# Cases
* [studygolang](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
* [Gitea](http://gitea.io) - [github.com/go-gitea/gitea](http://github.com/go-gitea/gitea)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [grafana](https://grafana.com/) - [github.com/grafana/grafana](http://github.com/grafana/grafana)
* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader)
* [Wego](http://github.com/go-tango/wego)
@ -268,8 +283,6 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [Gorevel](http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel)
* [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker)

View File

@ -8,10 +8,6 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm)
[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3)
# 注意
最新的版本有不兼容的更新,您必须使用 `engine.ShowSQL()``engine.Logger().SetLevel()` 来替代 `engine.ShowSQL = `, `engine.ShowInfo = ` 等等。
## 特性
* 支持Struct和数据库表之间的灵活映射并支持自动同步
@ -56,6 +52,15 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
## 更新日志
* **v0.6.3**
* 合并单元测试到主工程
* 新增`Exist`方法
* 新增`SumInt`方法
* Mysql新增读取和创建字段注释支持
* 新增`SetConnMaxLifetime`方法
* 修正了时间相关的Bug
* 修复了一些其它Bug
* **v0.6.2**
* 重构Tag解析方式
* Get方法新增类似Scan的特性
@ -72,18 +77,6 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* logging接口进行不兼容改变
* Bug修正
* **v0.4.5**
* bug修正
* extends 支持无限级
* Delete Limit 支持
* **v0.4.4**
* Tidb 数据库支持
* QL 试验性支持
* sql.NullString支持
* ForUpdate 支持
* bug修正
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
## 安装
@ -172,6 +165,25 @@ has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ?
```
* 检测记录是否存在
```Go
has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{
Name: "test1",
})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
```
* 查询多条记录当然可以使用Join和extends来组合使用
```Go
@ -263,6 +275,14 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
# 案例
* [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
* [Gitea](http://gitea.io) - [github.com/go-gitea/gitea](http://github.com/go-gitea/gitea)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [grafana](https://grafana.com/) - [github.com/grafana/grafana](http://github.com/grafana/grafana)
* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader)
* [Wego](http://github.com/go-tango/wego)
@ -271,8 +291,6 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker)
* [Gobuild.io](http://gobuild.io) - [github.com/shxsun/gobuild](http://github.com/shxsun/gobuild)

View File

@ -21,7 +21,7 @@ database:
test:
override:
# './...' is a relative pattern which means all subdirectories
- go test -v -race -db="sqlite3::mysql::postgres" -conn_str="./test.db::root:@/xorm_test::dbname=xorm_test sslmode=disable" -coverprofile=coverage.txt -covermode=atomic
- go test -v -race -db="sqlite3::mysql::mymysql::postgres" -conn_str="./test.db::root:@/xorm_test::xorm_test/root/::dbname=xorm_test sslmode=disable" -coverprofile=coverage.txt -covermode=atomic
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh

View File

@ -90,7 +90,7 @@ another is Rows
5. Update one or more records
affected, err := engine.Id(...).Update(&user)
affected, err := engine.ID(...).Update(&user)
// UPDATE user SET ...
6. Delete one or more records, Delete MUST has condition

View File

@ -169,7 +169,7 @@ func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
}
// SqlType will be depracated, please use SQLType instead
// SqlType will be deprecated, please use SQLType instead
//
// Deprecated: use SQLType instead
func (engine *Engine) SqlType(c *core.Column) string {
@ -205,14 +205,14 @@ func (engine *Engine) SetDefaultCacher(cacher core.Cacher) {
// you can use NoCache()
func (engine *Engine) NoCache() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoCache()
}
// NoCascade If you do not want to auto cascade load object
func (engine *Engine) NoCascade() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoCascade()
}
@ -245,7 +245,7 @@ func (engine *Engine) Dialect() core.Dialect {
// NewSession New a session
func (engine *Engine) NewSession() *Session {
session := &Session{Engine: engine}
session := &Session{engine: engine}
session.Init()
return session
}
@ -259,7 +259,6 @@ func (engine *Engine) Close() error {
func (engine *Engine) Ping() error {
session := engine.NewSession()
defer session.Close()
engine.logger.Infof("PING DATABASE %v", engine.DriverName())
return session.Ping()
}
@ -274,36 +273,6 @@ func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) {
}
}
func (engine *Engine) logSQLQueryTime(sqlStr string, args []interface{}, executionBlock func() (*core.Stmt, *core.Rows, error)) (*core.Stmt, *core.Rows, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
stmt, res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[SQL] %s %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
return stmt, res, err
}
return executionBlock()
}
func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, executionBlock func() (sql.Result, error)) (sql.Result, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[sql] %s [args] %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[sql] %s - took: %v", sqlStr, execDuration)
}
return res, err
}
return executionBlock()
}
// Sql provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.
//
@ -320,7 +289,7 @@ func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
// This code will execute "select * from user" and set the records to users
func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.SQL(query, args...)
}
@ -329,14 +298,14 @@ func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
// invoked. Call NoAutoTime if you dont' want to fill automatically.
func (engine *Engine) NoAutoTime() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoAutoTime()
}
// NoAutoCondition disable auto generate Where condition from bean or not
func (engine *Engine) NoAutoCondition(no ...bool) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.NoAutoCondition(no...)
}
@ -570,56 +539,56 @@ func (engine *Engine) tbName(v reflect.Value) string {
// Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Cascade(trueOrFalse...)
}
// Where method provide a condition query
func (engine *Engine) Where(query interface{}, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Where(query, args...)
}
// Id will be depracated, please use ID instead
// Id will be deprecated, please use ID instead
func (engine *Engine) Id(id interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Id(id)
}
// ID method provoide a condition as (id) = ?
func (engine *Engine) ID(id interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.ID(id)
}
// Before apply before Processor, affected bean is passed to closure arg
func (engine *Engine) Before(closures func(interface{})) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Before(closures)
}
// After apply after insert Processor, affected bean is passed to closure arg
func (engine *Engine) After(closures func(interface{})) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.After(closures)
}
// Charset set charset when create table, only support mysql now
func (engine *Engine) Charset(charset string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Charset(charset)
}
// StoreEngine set store engine when create table, only support mysql now
func (engine *Engine) StoreEngine(storeEngine string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.StoreEngine(storeEngine)
}
@ -628,35 +597,35 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session {
// but distinct will not provide id
func (engine *Engine) Distinct(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Distinct(columns...)
}
// Select customerize your select columns or contents
func (engine *Engine) Select(str string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Select(str)
}
// Cols only use the parameters as select or update columns
func (engine *Engine) Cols(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Cols(columns...)
}
// AllCols indicates that all columns should be use
func (engine *Engine) AllCols() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.AllCols()
}
// MustCols specify some columns must use even if they are empty
func (engine *Engine) MustCols(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.MustCols(columns...)
}
@ -667,77 +636,84 @@ func (engine *Engine) MustCols(columns ...string) *Session {
// it will use parameters's columns
func (engine *Engine) UseBool(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.UseBool(columns...)
}
// Omit only not use the parameters as select or update columns
func (engine *Engine) Omit(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Omit(columns...)
}
// Nullable set null when column is zero-value and nullable for update
func (engine *Engine) Nullable(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Nullable(columns...)
}
// In will generate "column IN (?, ?)"
func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.In(column, args...)
}
// NotIn will generate "column NOT IN (?, ?)"
func (engine *Engine) NotIn(column string, args ...interface{}) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.NotIn(column, args...)
}
// Incr provides a update string like "column = column + ?"
func (engine *Engine) Incr(column string, arg ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Incr(column, arg...)
}
// Decr provides a update string like "column = column - ?"
func (engine *Engine) Decr(column string, arg ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Decr(column, arg...)
}
// SetExpr provides a update string like "column = {expression}"
func (engine *Engine) SetExpr(column string, expression string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.SetExpr(column, expression)
}
// Table temporarily change the Get, Find, Update's table
func (engine *Engine) Table(tableNameOrBean interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Table(tableNameOrBean)
}
// Alias set the table alias
func (engine *Engine) Alias(alias string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Alias(alias)
}
// Limit will generate "LIMIT start, limit"
func (engine *Engine) Limit(limit int, start ...int) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Limit(limit, start...)
}
// Desc will generate "ORDER BY column1 DESC, column2 DESC"
func (engine *Engine) Desc(colNames ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Desc(colNames...)
}
@ -749,35 +725,35 @@ func (engine *Engine) Desc(colNames ...string) *Session {
//
func (engine *Engine) Asc(colNames ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Asc(colNames...)
}
// OrderBy will generate "ORDER BY order"
func (engine *Engine) OrderBy(order string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.OrderBy(order)
}
// Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Join(joinOperator, tablename, condition, args...)
}
// GroupBy generate group by statement
func (engine *Engine) GroupBy(keys string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.GroupBy(keys)
}
// Having generate having statement
func (engine *Engine) Having(conditions string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Having(conditions)
}
@ -1208,6 +1184,9 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
for _, bean := range beans {
v := rValue(bean)
tableName := engine.tbName(v)
@ -1216,14 +1195,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
s := engine.NewSession()
defer s.Close()
isExist, err := s.Table(bean).isTableExist(tableName)
isExist, err := session.Table(bean).isTableExist(tableName)
if err != nil {
return err
}
if !isExist {
err = engine.CreateTables(bean)
err = session.createTable(bean)
if err != nil {
return err
}
@ -1234,11 +1211,11 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}*/
var isEmpty bool
if isEmpty {
err = engine.DropTables(bean)
err = session.dropTable(bean)
if err != nil {
return err
}
err = engine.CreateTables(bean)
err = session.createTable(bean)
if err != nil {
return err
}
@ -1249,9 +1226,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
session := engine.NewSession()
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err
}
err = session.addColumn(col.Name)
@ -1262,9 +1237,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
for name, index := range table.Indexes {
session := engine.NewSession()
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err
}
if index.Type == core.UniqueType {
@ -1273,9 +1246,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
session := engine.NewSession()
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err
}
@ -1290,9 +1261,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
session := engine.NewSession()
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err
}
@ -1328,7 +1297,7 @@ func (engine *Engine) CreateTables(beans ...interface{}) error {
}
for _, bean := range beans {
err = session.CreateTable(bean)
err = session.createTable(bean)
if err != nil {
session.Rollback()
return err
@ -1348,7 +1317,7 @@ func (engine *Engine) DropTables(beans ...interface{}) error {
}
for _, bean := range beans {
err = session.DropTable(bean)
err = session.dropTable(bean)
if err != nil {
session.Rollback()
return err
@ -1385,6 +1354,13 @@ func (engine *Engine) QueryString(sqlStr string, args ...interface{}) ([]map[str
return session.QueryString(sqlStr, args...)
}
// QueryInterface runs a raw sql and return records as []map[string]interface{}
func (engine *Engine) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) {
session := engine.NewSession()
defer session.Close()
return session.QueryInterface(sqlStr, args...)
}
// Insert one or more records
func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
session := engine.NewSession()
@ -1426,6 +1402,13 @@ func (engine *Engine) Get(bean interface{}) (bool, error) {
return session.Get(bean)
}
// Exist returns true if the record exist otherwise return false
func (engine *Engine) Exist(bean ...interface{}) (bool, error) {
session := engine.NewSession()
defer session.Close()
return session.Exist(bean...)
}
// Find retrieve records from table, condiBeans's non-empty fields
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct
@ -1451,10 +1434,10 @@ func (engine *Engine) Rows(bean interface{}) (*Rows, error) {
}
// Count counts the records. bean's non-empty fields are conditions.
func (engine *Engine) Count(bean interface{}) (int64, error) {
func (engine *Engine) Count(bean ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.Count(bean)
return session.Count(bean...)
}
// Sum sum the records by some column. bean's non-empty fields are conditions.
@ -1580,7 +1563,7 @@ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}
// Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session {
session := engine.NewSession()
session.IsAutoClose = true
session.isAutoClose = true
return session.Unscoped()
}

View File

@ -67,7 +67,7 @@ func main() {
fmt.Println("users3:", users3)
user4 := new(User)
has, err := Orm.Id(1).Get(user4)
has, err := Orm.ID(1).Get(user4)
if err != nil {
fmt.Println(err)
return
@ -76,7 +76,7 @@ func main() {
fmt.Println("user4:", has, user4)
user4.Name = "xiaolunwen"
_, err = Orm.Id(1).Update(user4)
_, err = Orm.ID(1).Update(user4)
if err != nil {
fmt.Println(err)
return
@ -84,14 +84,14 @@ func main() {
fmt.Println("user4:", user4)
user5 := new(User)
has, err = Orm.Id(1).Get(user5)
has, err = Orm.ID(1).Get(user5)
if err != nil {
fmt.Println(err)
return
}
fmt.Println("user5:", has, user5)
_, err = Orm.Id(1).Delete(new(User))
_, err = Orm.ID(1).Delete(new(User))
if err != nil {
fmt.Println(err)
return
@ -99,7 +99,7 @@ func main() {
for {
user6 := new(User)
has, err = Orm.Id(1).Get(user6)
has, err = Orm.ID(1).Get(user6)
if err != nil {
fmt.Println(err)
return

View File

@ -55,7 +55,7 @@ func test(engine *xorm.Engine) {
} else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 {
//_, err = engine.Id(1).Delete(u)
//_, err = engine.ID(1).Delete(u)
_, err = engine.Delete(u)
}
if err != nil {

View File

@ -51,7 +51,7 @@ func main() {
}
info := LoginInfo{}
_, err = orm.Id(1).Get(&info)
_, err = orm.ID(1).Get(&info)
if err != nil {
fmt.Println(err)
return

View File

@ -59,7 +59,7 @@ func test(engine *xorm.Engine) {
} else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 {
_, err = engine.Id(1).Delete(u)
_, err = engine.ID(1).Delete(u)
}
if err != nil {
fmt.Println(err)

View File

@ -62,7 +62,7 @@ func test(engine *xorm.Engine) {
} else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 {
_, err = engine.Id(1).Delete(u)
_, err = engine.ID(1).Delete(u)
}
if err != nil {
fmt.Println(err)

View File

@ -48,7 +48,7 @@ func main() {
}
info := LoginInfo{}
_, err = orm.Id(1).Get(&info)
_, err = orm.ID(1).Get(&info)
if err != nil {
fmt.Println(err)
return

View File

@ -358,7 +358,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
for _, col := range table.Columns() {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
@ -397,32 +397,32 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
} else if _, ok := session.Statement.incrColumns[col.Name]; ok {
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.Statement.decrColumns[col.Name]; ok {
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.Statement.nullableMap, col); ok {
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.Engine.NowTime2(col.SQLType.Name)
val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val)
var colName = col.Name
@ -430,7 +430,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
@ -441,7 +441,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
}
if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}

View File

@ -173,7 +173,7 @@ func TestProcessors(t *testing.T) {
}
p2 := &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2)
_, err = testEngine.ID(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -308,7 +308,7 @@ func TestProcessors(t *testing.T) {
}
p2 = &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2)
_, err = testEngine.ID(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -402,7 +402,7 @@ func TestProcessors(t *testing.T) {
for _, elem := range pslice {
p = &ProcessorsStruct{}
_, err = testEngine.Id(elem.Id).Get(p)
_, err = testEngine.ID(elem.Id).Get(p)
if err != nil {
t.Error(err)
panic(err)
@ -508,7 +508,7 @@ func TestProcessorsTx(t *testing.T) {
}
session.Close()
p2 := &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2)
_, err = testEngine.ID(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -569,7 +569,7 @@ func TestProcessorsTx(t *testing.T) {
}
session.Close()
p2 = &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2)
_, err = testEngine.ID(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -616,7 +616,7 @@ func TestProcessorsTx(t *testing.T) {
p = p2 // reset
_, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
_, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil {
t.Error(err)
panic(err)
@ -656,7 +656,7 @@ func TestProcessorsTx(t *testing.T) {
session.Close()
p2 = &ProcessorsStruct{}
_, err = testEngine.Id(insertedId).Get(p2)
_, err = testEngine.ID(insertedId).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -729,7 +729,7 @@ func TestProcessorsTx(t *testing.T) {
p = &ProcessorsStruct{}
_, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
_, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil {
t.Error(err)
panic(err)
@ -767,7 +767,7 @@ func TestProcessorsTx(t *testing.T) {
}
session.Close()
p2 = &ProcessorsStruct{}
_, err = testEngine.Id(insertedId).Get(p2)
_, err = testEngine.ID(insertedId).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -813,7 +813,7 @@ func TestProcessorsTx(t *testing.T) {
p = &ProcessorsStruct{} // reset
_, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
_, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil {
t.Error(err)
panic(err)
@ -852,7 +852,7 @@ func TestProcessorsTx(t *testing.T) {
session.Close()
p2 = &ProcessorsStruct{}
_, err = testEngine.Id(insertedId).Get(p2)
_, err = testEngine.ID(insertedId).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -882,7 +882,7 @@ func TestProcessorsTx(t *testing.T) {
p = &ProcessorsStruct{}
_, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
_, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil {
t.Error(err)
panic(err)

View File

@ -17,7 +17,6 @@ type Rows struct {
NoTypeCheck bool
session *Session
stmt *core.Stmt
rows *core.Rows
fields []string
beanType reflect.Type
@ -29,56 +28,33 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
rows.session = session
rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
defer rows.session.resetStatement()
var sqlStr string
var args []interface{}
var err error
if err = rows.session.Statement.setRefValue(rValue(bean)); err != nil {
if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
return nil, err
}
if len(session.Statement.TableName()) <= 0 {
if len(session.statement.TableName()) <= 0 {
return nil, ErrTableNotFound
}
if rows.session.Statement.RawSQL == "" {
sqlStr, args, err = rows.session.Statement.genGetSQL(bean)
if rows.session.statement.RawSQL == "" {
sqlStr, args, err = rows.session.statement.genGetSQL(bean)
if err != nil {
return nil, err
}
} else {
sqlStr = rows.session.Statement.RawSQL
args = rows.session.Statement.RawParams
sqlStr = rows.session.statement.RawSQL
args = rows.session.statement.RawParams
}
for _, filter := range rows.session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable)
}
rows.session.saveLastSQL(sqlStr, args...)
if rows.session.prepareStmt {
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.rows, err = rows.stmt.Query(args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
} else {
rows.rows, err = rows.session.DB().Query(sqlStr, args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.rows, err = rows.session.queryRows(sqlStr, args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.fields, err = rows.rows.Columns()
@ -119,7 +95,7 @@ func (rows *Rows) Scan(bean interface{}) error {
}
dataStruct := rValue(bean)
if err := rows.session.Statement.setRefValue(dataStruct); err != nil {
if err := rows.session.statement.setRefValue(dataStruct); err != nil {
return err
}
@ -128,13 +104,13 @@ func (rows *Rows) Scan(bean interface{}) error {
return err
}
_, err = rows.session.slice2Bean(scanResults, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.Statement.RefTable)
_, err = rows.session.slice2Bean(scanResults, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.statement.RefTable)
return err
}
// Close session if session.IsAutoClose is true, and claimed any opened resources
func (rows *Rows) Close() error {
if rows.session.IsAutoClose {
if rows.session.isAutoClose {
defer rows.session.Close()
}
@ -142,17 +118,10 @@ func (rows *Rows) Close() error {
if rows.rows != nil {
rows.lastError = rows.rows.Close()
if rows.lastError != nil {
defer rows.stmt.Close()
return rows.lastError
}
}
if rows.stmt != nil {
rows.lastError = rows.stmt.Close()
}
} else {
if rows.stmt != nil {
defer rows.stmt.Close()
}
if rows.rows != nil {
defer rows.rows.Close()
}

View File

@ -38,4 +38,31 @@ func TestRows(t *testing.T) {
cnt++
}
assert.EqualValues(t, 1, cnt)
sess := testEngine.NewSession()
defer sess.Close()
rows1, err := sess.Prepare().Rows(new(UserRows))
assert.NoError(t, err)
defer rows1.Close()
cnt = 0
for rows1.Next() {
err = rows1.Scan(user)
assert.NoError(t, err)
cnt++
}
assert.EqualValues(t, 1, cnt)
rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows))
assert.NoError(t, err)
defer rows2.Close()
cnt = 0
for rows2.Next() {
err = rows2.Scan(user)
assert.NoError(t, err)
cnt++
}
assert.EqualValues(t, 1, cnt)
}

View File

@ -21,16 +21,16 @@ import (
// kind of database operations.
type Session struct {
db *core.DB
Engine *Engine
Tx *core.Tx
Statement Statement
IsAutoCommit bool
IsCommitedOrRollbacked bool
IsAutoClose bool
engine *Engine
tx *core.Tx
statement Statement
isAutoCommit bool
isCommitedOrRollbacked bool
isAutoClose bool
// Automatically reset the statement after operations that execute a SQL
// query such as Count(), Find(), Get(), ...
AutoResetStatement bool
autoResetStatement bool
// !nashtsai! storing these beans due to yet committed tx
afterInsertBeans map[interface{}]*[]func(interface{})
@ -60,12 +60,12 @@ func (session *Session) Clone() *Session {
// Init reset the session as the init status.
func (session *Session) Init() {
session.Statement.Init()
session.Statement.Engine = session.Engine
session.IsAutoCommit = true
session.IsCommitedOrRollbacked = false
session.IsAutoClose = false
session.AutoResetStatement = true
session.statement.Init()
session.statement.Engine = session.engine
session.isAutoCommit = true
session.isCommitedOrRollbacked = false
session.isAutoClose = false
session.autoResetStatement = true
session.prepareStmt = false
// !nashtsai! is lazy init better?
@ -88,19 +88,23 @@ func (session *Session) Close() {
if session.db != nil {
// When Close be called, if session is a transaction and do not call
// Commit or Rollback, then call Rollback.
if session.Tx != nil && !session.IsCommitedOrRollbacked {
if session.tx != nil && !session.isCommitedOrRollbacked {
session.Rollback()
}
session.Tx = nil
session.tx = nil
session.stmtCache = nil
session.Init()
session.db = nil
}
}
// IsClosed returns if session is closed
func (session *Session) IsClosed() bool {
return session.db == nil
}
func (session *Session) resetStatement() {
if session.AutoResetStatement {
session.Statement.Init()
if session.autoResetStatement {
session.statement.Init()
}
}
@ -128,75 +132,75 @@ func (session *Session) After(closures func(interface{})) *Session {
// Table can input a string or pointer to struct for special a table to operate.
func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.Statement.Table(tableNameOrBean)
session.statement.Table(tableNameOrBean)
return session
}
// Alias set the table alias
func (session *Session) Alias(alias string) *Session {
session.Statement.Alias(alias)
session.statement.Alias(alias)
return session
}
// NoCascade indicate that no cascade load child object
func (session *Session) NoCascade() *Session {
session.Statement.UseCascade = false
session.statement.UseCascade = false
return session
}
// ForUpdate Set Read/Write locking for UPDATE
func (session *Session) ForUpdate() *Session {
session.Statement.IsForUpdate = true
session.statement.IsForUpdate = true
return session
}
// NoAutoCondition disable generate SQL condition from beans
func (session *Session) NoAutoCondition(no ...bool) *Session {
session.Statement.NoAutoCondition(no...)
session.statement.NoAutoCondition(no...)
return session
}
// Limit provide limit and offset query condition
func (session *Session) Limit(limit int, start ...int) *Session {
session.Statement.Limit(limit, start...)
session.statement.Limit(limit, start...)
return session
}
// OrderBy provide order by query condition, the input parameter is the content
// after order by on a sql statement.
func (session *Session) OrderBy(order string) *Session {
session.Statement.OrderBy(order)
session.statement.OrderBy(order)
return session
}
// Desc provide desc order by query condition, the input parameters are columns.
func (session *Session) Desc(colNames ...string) *Session {
session.Statement.Desc(colNames...)
session.statement.Desc(colNames...)
return session
}
// Asc provide asc order by query condition, the input parameters are columns.
func (session *Session) Asc(colNames ...string) *Session {
session.Statement.Asc(colNames...)
session.statement.Asc(colNames...)
return session
}
// StoreEngine is only avialble mysql dialect currently
func (session *Session) StoreEngine(storeEngine string) *Session {
session.Statement.StoreEngine = storeEngine
session.statement.StoreEngine = storeEngine
return session
}
// Charset is only avialble mysql dialect currently
func (session *Session) Charset(charset string) *Session {
session.Statement.Charset = charset
session.statement.Charset = charset
return session
}
// Cascade indicates if loading sub Struct
func (session *Session) Cascade(trueOrFalse ...bool) *Session {
if len(trueOrFalse) >= 1 {
session.Statement.UseCascade = trueOrFalse[0]
session.statement.UseCascade = trueOrFalse[0]
}
return session
}
@ -204,32 +208,32 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session {
// NoCache ask this session do not retrieve data from cache system and
// get data from database directly.
func (session *Session) NoCache() *Session {
session.Statement.UseCache = false
session.statement.UseCache = false
return session
}
// Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session.Statement.Join(joinOperator, tablename, condition, args...)
session.statement.Join(joinOperator, tablename, condition, args...)
return session
}
// GroupBy Generate Group By statement
func (session *Session) GroupBy(keys string) *Session {
session.Statement.GroupBy(keys)
session.statement.GroupBy(keys)
return session
}
// Having Generate Having statement
func (session *Session) Having(conditions string) *Session {
session.Statement.Having(conditions)
session.statement.Having(conditions)
return session
}
// DB db return the wrapper of sql.DB
func (session *Session) DB() *core.DB {
if session.db == nil {
session.db = session.Engine.db
session.db = session.engine.db
session.stmtCache = make(map[uint32]*core.Stmt, 0)
}
return session.db
@ -242,13 +246,13 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) {
}
func (session *Session) canCache() bool {
if session.Statement.RefTable == nil ||
session.Statement.JoinStr != "" ||
session.Statement.RawSQL != "" ||
!session.Statement.UseCache ||
session.Statement.IsForUpdate ||
session.Tx != nil ||
len(session.Statement.selectStr) > 0 {
if session.statement.RefTable == nil ||
session.statement.JoinStr != "" ||
session.statement.RawSQL != "" ||
!session.statement.UseCache ||
session.statement.IsForUpdate ||
session.tx != nil ||
len(session.statement.selectStr) > 0 {
return false
}
return true
@ -272,18 +276,18 @@ func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) {
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value {
var col *core.Column
if col = table.GetColumnIdx(key, idx); col == nil {
//session.Engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq())
//session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq())
return nil
}
fieldValue, err := col.ValueOfV(dataStruct)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return nil
}
if !fieldValue.IsValid() || !fieldValue.CanSet() {
session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
return nil
}
return fieldValue
@ -528,7 +532,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
}
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
dbTZ := session.Engine.DatabaseTZ
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
}
@ -541,25 +545,25 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
t = t.In(session.Engine.TZLocation)
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type {
hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.Engine.TZLocation)
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else {
if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.Engine.logger.Error("byte2Time error:", err.Error())
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
@ -568,7 +572,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.Engine.logger.Error("byte2Time error:", err.Error())
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
@ -581,7 +585,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.Engine.logger.Error("sql.Sanner error:", err.Error())
session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if col.SQLType.IsJson() {
@ -606,8 +610,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
fieldValue.Set(x.Elem())
}
}
} else if session.Statement.UseCascade {
table, err := session.Engine.autoMapType(*fieldValue)
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil {
return nil, err
}
@ -627,9 +631,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.ID(pk).NoCascade().Get(structInter.Interface())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
@ -773,19 +775,11 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
return pk, nil
}
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.Engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
session.saveLastSQL(*sqlStr, paramStr...)
}
// saveLastSQL stores executed query information
func (session *Session) saveLastSQL(sql string, args ...interface{}) {
session.lastSQL = sql
session.lastSQLArgs = args
session.Engine.logSQL(sql, args...)
session.engine.logSQL(sql, args...)
}
// LastSQL returns last query information
@ -795,8 +789,8 @@ func (session *Session) LastSQL() (string, []interface{}) {
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.Statement.AltTableName) > 0 {
return session.Statement.AltTableName
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
@ -804,6 +798,6 @@ func (session *Session) tbNameNoSchema(table *core.Table) string {
// Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session {
session.Statement.Unscoped()
session.statement.Unscoped()
return session
}

View File

@ -6,43 +6,43 @@ package xorm
// Incr provides a query string like "count = count + 1"
func (session *Session) Incr(column string, arg ...interface{}) *Session {
session.Statement.Incr(column, arg...)
session.statement.Incr(column, arg...)
return session
}
// Decr provides a query string like "count = count - 1"
func (session *Session) Decr(column string, arg ...interface{}) *Session {
session.Statement.Decr(column, arg...)
session.statement.Decr(column, arg...)
return session
}
// SetExpr provides a query string like "column = {expression}"
func (session *Session) SetExpr(column string, expression string) *Session {
session.Statement.SetExpr(column, expression)
session.statement.SetExpr(column, expression)
return session
}
// Select provides some columns to special
func (session *Session) Select(str string) *Session {
session.Statement.Select(str)
session.statement.Select(str)
return session
}
// Cols provides some columns to special
func (session *Session) Cols(columns ...string) *Session {
session.Statement.Cols(columns...)
session.statement.Cols(columns...)
return session
}
// AllCols ask all columns
func (session *Session) AllCols() *Session {
session.Statement.AllCols()
session.statement.AllCols()
return session
}
// MustCols specify some columns must use even if they are empty
func (session *Session) MustCols(columns ...string) *Session {
session.Statement.MustCols(columns...)
session.statement.MustCols(columns...)
return session
}
@ -52,7 +52,7 @@ func (session *Session) MustCols(columns ...string) *Session {
// If no parameters, it will use all the bool field of struct, or
// it will use parameters's columns
func (session *Session) UseBool(columns ...string) *Session {
session.Statement.UseBool(columns...)
session.statement.UseBool(columns...)
return session
}
@ -60,25 +60,25 @@ func (session *Session) UseBool(columns ...string) *Session {
// distinct will not be cached because cache system need id,
// but distinct will not provide id
func (session *Session) Distinct(columns ...string) *Session {
session.Statement.Distinct(columns...)
session.statement.Distinct(columns...)
return session
}
// Omit Only not use the parameters as select or update columns
func (session *Session) Omit(columns ...string) *Session {
session.Statement.Omit(columns...)
session.statement.Omit(columns...)
return session
}
// Nullable Set null when column is zero-value and nullable for update
func (session *Session) Nullable(columns ...string) *Session {
session.Statement.Nullable(columns...)
session.statement.Nullable(columns...)
return session
}
// NoAutoTime means do not automatically give created field and updated field
// the current time on the current session temporarily
func (session *Session) NoAutoTime() *Session {
session.Statement.UseAutoTime = false
session.statement.UseAutoTime = false
return session
}

View File

@ -31,7 +31,39 @@ func TestSetExpr(t *testing.T) {
if testEngine.dialect.DBType() == core.MSSQL {
not = "~"
}
cnt, err = testEngine.SetExpr("show", not+" `show`").Id(1).Update(new(User))
cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(User))
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}
func TestCols(t *testing.T) {
assert.NoError(t, prepareEngine())
type ColsTable struct {
Id int64
Col1 string
Col2 string
}
assertSync(t, new(ColsTable))
_, err := testEngine.Insert(&ColsTable{
Col1: "1",
Col2: "2",
})
assert.NoError(t, err)
sess := testEngine.ID(1)
_, err = sess.Cols("col1").Cols("col2").Update(&ColsTable{
Col1: "",
Col2: "",
})
assert.NoError(t, err)
var tb ColsTable
has, err := testEngine.ID(1).Get(&tb)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "", tb.Col1)
assert.EqualValues(t, "", tb.Col2)
}

View File

@ -17,25 +17,25 @@ func (session *Session) Sql(query string, args ...interface{}) *Session {
// SQL provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.
func (session *Session) SQL(query interface{}, args ...interface{}) *Session {
session.Statement.SQL(query, args...)
session.statement.SQL(query, args...)
return session
}
// Where provides custom query condition.
func (session *Session) Where(query interface{}, args ...interface{}) *Session {
session.Statement.Where(query, args...)
session.statement.Where(query, args...)
return session
}
// And provides custom query condition.
func (session *Session) And(query interface{}, args ...interface{}) *Session {
session.Statement.And(query, args...)
session.statement.And(query, args...)
return session
}
// Or provides custom query condition.
func (session *Session) Or(query interface{}, args ...interface{}) *Session {
session.Statement.Or(query, args...)
session.statement.Or(query, args...)
return session
}
@ -48,23 +48,23 @@ func (session *Session) Id(id interface{}) *Session {
// ID provides converting id as a query condition
func (session *Session) ID(id interface{}) *Session {
session.Statement.ID(id)
session.statement.ID(id)
return session
}
// In provides a query string like "id in (1, 2, 3)"
func (session *Session) In(column string, args ...interface{}) *Session {
session.Statement.In(column, args...)
session.statement.In(column, args...)
return session
}
// NotIn provides a query string like "id in (1, 2, 3)"
func (session *Session) NotIn(column string, args ...interface{}) *Session {
session.Statement.NotIn(column, args...)
session.statement.NotIn(column, args...)
return session
}
// Conds returns session query conditions
// Conds returns session query conditions except auto bean conditions
func (session *Session) Conds() builder.Cond {
return session.Statement.cond
return session.statement.cond
}

View File

@ -83,6 +83,11 @@ func TestBuilder(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.NotIn("col_name", "col1", "col2").Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 0, len(conds), "records should not exist")
// complex condtions
var where = builder.NewCond()
if true {
@ -222,7 +227,7 @@ func TestIn(t *testing.T) {
}
user := new(Userinfo)
has, err := testEngine.Id(ids[0]).Get(user)
has, err := testEngine.ID(ids[0]).Get(user)
if err != nil {
t.Error(err)
panic(err)
@ -260,3 +265,35 @@ func TestIn(t *testing.T) {
panic(err)
}
}
func TestFindAndCount(t *testing.T) {
assert.NoError(t, prepareEngine())
type FindAndCount struct {
Id int64
Name string
}
assert.NoError(t, testEngine.Sync2(new(FindAndCount)))
_, err := testEngine.Insert([]FindAndCount{
{
Name: "test1",
},
{
Name: "test2",
},
})
assert.NoError(t, err)
var results []FindAndCount
sess := testEngine.Where("name = ?", "test1")
conds := sess.Conds()
err = sess.Find(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
total, err := testEngine.Where(conds).Count(new(FindAndCount))
assert.NoError(t, err)
assert.EqualValues(t, 1, total)
}

View File

@ -23,7 +23,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
var x time.Time
var err error
var parseLoc = session.Engine.DatabaseTZ
var parseLoc = session.engine.DatabaseTZ
if col.TimeZone != nil {
parseLoc = col.TimeZone
}
@ -34,27 +34,27 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil {
x = time.Unix(sd, 0)
session.Engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
session.Engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.Engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
session.Engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
session.Engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
session.Engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
session.Engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == core.Time {
if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ")
@ -62,13 +62,13 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
}
sdata = strings.TrimSpace(sdata)
if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:]
}
st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
session.Engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
outErr = fmt.Errorf("unsupported time format %v", sdata)
return
@ -77,7 +77,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err)
return
}
outTime = x.In(session.Engine.TZLocation)
outTime = x.In(session.engine.TZLocation)
return
}
@ -105,7 +105,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@ -119,7 +119,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@ -132,7 +132,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@ -156,7 +156,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
session.Engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if len(data) == 1 {
x = int64(data[0])
} else {
@ -204,8 +204,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if session.Statement.UseCascade {
table, err := session.Engine.autoMapType(*fieldValue)
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil {
return err
}
@ -227,9 +227,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return err
}
@ -264,7 +262,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, &x)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@ -275,7 +273,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 {
err := json.Unmarshal(data, &x)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return err
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@ -347,7 +345,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int64(data[0])
} else {
@ -372,7 +370,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int(data[0])
} else {
@ -400,7 +398,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
session.Engine.dialect.DBType() == core.MYSQL {
session.engine.dialect.DBType() == core.MYSQL {
if len(data) == 1 {
x = int32(data[0])
} else {
@ -428,7 +426,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int8(data[0])
} else {
@ -456,7 +454,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") {
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int16(data[0])
} else {
@ -488,9 +486,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
v = x
fieldValue.Set(reflect.ValueOf(&x))
default:
if session.Statement.UseCascade {
if session.statement.UseCascade {
structInter := reflect.New(fieldType.Elem())
table, err := session.Engine.autoMapType(structInter.Elem())
table, err := session.engine.autoMapType(structInter.Elem())
if err != nil {
return err
}
@ -510,9 +508,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return err
}
@ -569,7 +565,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if fieldValue.IsNil() {
return nil, nil
} else if !fieldValue.IsValid() {
session.Engine.logger.Warn("the field[", col.FieldName, "] is invalid")
session.engine.logger.Warn("the field[", col.FieldName, "] is invalid")
return nil, nil
} else {
// !nashtsai! deference pointer type to instance type
@ -587,7 +583,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
tf := session.Engine.formatColTime(col, t)
tf := session.engine.formatColTime(col, t)
return tf, nil
}
@ -597,7 +593,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
return v.Value()
}
fieldTable, err := session.Engine.autoMapType(fieldValue)
fieldTable, err := session.engine.autoMapType(fieldValue)
if err != nil {
return nil, err
}
@ -611,14 +607,14 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return bytes, nil
@ -627,7 +623,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.Complex64, reflect.Complex128:
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
@ -639,7 +635,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
@ -652,7 +648,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
return 0, err
}
}

View File

@ -13,25 +13,25 @@ import (
)
func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil ||
session.Tx != nil {
if session.statement.RefTable == nil ||
session.tx != nil {
return ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
newsql := session.Statement.convertIDSQL(sqlStr)
newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
cacher := session.Engine.getCacher2(session.Statement.RefTable)
tableName := session.Statement.TableName()
cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.statement.TableName()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
resultsSlice, err := session.query(newsql, args...)
resultsSlice, err := session.queryBytes(newsql, args...)
if err != nil {
return err
}
@ -40,7 +40,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
for _, data := range resultsSlice {
var id int64
var pk core.PK = make([]interface{}, 0)
for _, col := range session.Statement.RefTable.PKColumns() {
for _, col := range session.statement.RefTable.PKColumns() {
if v, ok := data[col.Name]; !ok {
return errors.New("no id")
} else if col.SQLType.IsText() {
@ -59,34 +59,33 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
}
}
} /*else {
session.Engine.LogDebug("delete cache sql %v", newsql)
session.engine.LogDebug("delete cache sql %v", newsql)
cacher.DelIds(tableName, genSqlKey(newsql, args))
}*/
for _, id := range ids {
session.Engine.logger.Debug("[cacheDelete] delete cache obj", tableName, id)
session.engine.logger.Debug("[cacheDelete] delete cache obj", tableName, id)
sid, err := id.ToString()
if err != nil {
return err
}
cacher.DelBean(tableName, sid)
}
session.Engine.logger.Debug("[cacheDelete] clear cache sql", tableName)
session.engine.logger.Debug("[cacheDelete] clear cache sql", tableName)
cacher.ClearIds(tableName)
return nil
}
// Delete records, bean's non-empty fields are conditions
func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
if err := session.Statement.setRefValue(rValue(bean)); err != nil {
if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err
}
var table = session.Statement.RefTable
var table = session.statement.RefTable
// handle before delete processors
for _, closure := range session.beforeClosures {
@ -98,15 +97,15 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
processor.BeforeDelete()
}
condSQL, condArgs, err := session.Statement.genConds(bean)
condSQL, condArgs, err := session.statement.genConds(bean)
if err != nil {
return 0, err
}
if len(condSQL) == 0 && session.Statement.LimitN == 0 {
if len(condSQL) == 0 && session.statement.LimitN == 0 {
return 0, ErrNeedDeletedCond
}
var tableName = session.Engine.Quote(session.Statement.TableName())
var tableName = session.engine.Quote(session.statement.TableName())
var deleteSQL string
if len(condSQL) > 0 {
deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL)
@ -115,15 +114,15 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}
var orderSQL string
if len(session.Statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.Statement.OrderStr)
if len(session.statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
}
if session.Statement.LimitN > 0 {
orderSQL += fmt.Sprintf(" LIMIT %d", session.Statement.LimitN)
if session.statement.LimitN > 0 {
orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN)
}
if len(orderSQL) > 0 {
switch session.Engine.dialect.DBType() {
switch session.engine.dialect.DBType() {
case core.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
@ -148,7 +147,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
var realSQL string
argsForCache := make([]interface{}, 0, len(condArgs)*2)
if session.Statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled
if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled
realSQL = deleteSQL
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
@ -159,12 +158,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
deletedColumn := table.DeletedColumn()
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v",
session.Engine.Quote(session.Statement.TableName()),
session.Engine.Quote(deletedColumn.Name),
session.engine.Quote(session.statement.TableName()),
session.engine.Quote(deletedColumn.Name),
condSQL)
if len(orderSQL) > 0 {
switch session.Engine.dialect.DBType() {
switch session.engine.dialect.DBType() {
case core.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
@ -187,12 +186,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}
}
// !oinume! Insert NowTime to the head of session.Statement.Params
// !oinume! Insert NowTime to the head of session.statement.Params
condArgs = append(condArgs, "")
paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
val, t := session.Engine.NowTime2(deletedColumn.SQLType.Name)
val, t := session.engine.NowTime2(deletedColumn.SQLType.Name)
condArgs[0] = val
var colName = deletedColumn.Name
@ -202,7 +201,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
})
}
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache {
if cacher := session.engine.getCacher2(session.statement.RefTable); cacher != nil && session.statement.UseCache {
session.cacheDelete(deleteSQL, argsForCache...)
}
@ -212,7 +211,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}
// handle after delete processors
if session.IsAutoCommit {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}

View File

@ -26,13 +26,27 @@ func TestDelete(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.Delete(&UserinfoDelete{Uid: 1})
cnt, err = testEngine.Delete(&UserinfoDelete{Uid: user.Uid})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
user.Uid = 0
user.IsMan = true
has, err := testEngine.Id(1).Get(&user)
has, err := testEngine.ID(1).Get(&user)
assert.NoError(t, err)
assert.False(t, has)
cnt, err = testEngine.Insert(&user)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.Where("id=?", user.Uid).Delete(&UserinfoDelete{})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
user.Uid = 0
user.IsMan = true
has, err = testEngine.ID(2).Get(&user)
assert.NoError(t, err)
assert.False(t, has)
}
@ -68,16 +82,16 @@ func TestDeleted(t *testing.T) {
// Test normal Get()
record1 := &Deleted{}
has, err := testEngine.Id(1).Get(record1)
has, err := testEngine.ID(1).Get(record1)
assert.NoError(t, err)
assert.True(t, has)
// Test Delete() with deleted
affected, err := testEngine.Id(1).Delete(&Deleted{})
affected, err := testEngine.ID(1).Delete(&Deleted{})
assert.NoError(t, err)
assert.EqualValues(t, 1, affected)
has, err = testEngine.Id(1).Get(&Deleted{})
has, err = testEngine.ID(1).Get(&Deleted{})
assert.NoError(t, err)
assert.False(t, has)
@ -87,17 +101,17 @@ func TestDeleted(t *testing.T) {
assert.EqualValues(t, 2, len(records2))
// Test no rows affected after Delete() again.
affected, err = testEngine.Id(1).Delete(&Deleted{})
affected, err = testEngine.ID(1).Delete(&Deleted{})
assert.NoError(t, err)
assert.EqualValues(t, 0, affected)
// Deleted.DeletedAt must not be updated.
affected, err = testEngine.Id(2).Update(&Deleted{Name: "2", DeletedAt: time.Now()})
affected, err = testEngine.ID(2).Update(&Deleted{Name: "2", DeletedAt: time.Now()})
assert.NoError(t, err)
assert.EqualValues(t, 1, affected)
record2 := &Deleted{}
has, err = testEngine.Id(2).Get(record2)
has, err = testEngine.ID(2).Get(record2)
assert.NoError(t, err)
assert.True(t, record2.DeletedAt.IsZero())
@ -108,7 +122,7 @@ func TestDeleted(t *testing.T) {
assert.EqualValues(t, 3, len(unscopedRecords1))
// Delete() must really delete a record with Unscoped()
affected, err = testEngine.Unscoped().Id(1).Delete(&Deleted{})
affected, err = testEngine.Unscoped().ID(1).Delete(&Deleted{})
assert.NoError(t, err)
assert.EqualValues(t, 1, affected)

View File

@ -23,11 +23,13 @@ const (
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct
func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
return session.find(rowsSlicePtr, condiBean...)
}
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return errors.New("needs a pointer to a slice or a map")
@ -36,11 +38,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem()
var tp = tpStruct
if session.Statement.RefTable == nil {
if session.statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem())
if err := session.Statement.setRefValue(pv.Elem()); err != nil {
if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err
}
} else {
@ -48,7 +50,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
} else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType)
if err := session.Statement.setRefValue(pv.Elem()); err != nil {
if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err
}
} else {
@ -56,31 +58,31 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
}
var table = session.Statement.RefTable
var table = session.statement.RefTable
var addedTableName = (len(session.Statement.JoinStr) > 0)
var addedTableName = (len(session.statement.JoinStr) > 0)
var autoCond builder.Cond
if tp == tpStruct {
if !session.Statement.noAutoCondition && len(condiBean) > 0 {
if !session.statement.noAutoCondition && len(condiBean) > 0 {
var err error
autoCond, err = session.Statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName)
autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName)
if err != nil {
return err
}
} else {
// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
// See https://github.com/go-xorm/xorm/issues/179
if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled
var colName = session.Engine.Quote(col.Name)
if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled
var colName = session.engine.Quote(col.Name)
if addedTableName {
var nm = session.Statement.TableName()
if len(session.Statement.TableAlias) > 0 {
nm = session.Statement.TableAlias
var nm = session.statement.TableName()
if len(session.statement.TableAlias) > 0 {
nm = session.statement.TableAlias
}
colName = session.Engine.Quote(nm) + "." + colName
colName = session.engine.Quote(nm) + "." + colName
}
autoCond = session.Engine.CondDeleted(colName)
autoCond = session.engine.CondDeleted(colName)
}
}
}
@ -88,27 +90,27 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
var sqlStr string
var args []interface{}
var err error
if session.Statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 {
if session.statement.RawSQL == "" {
if len(session.statement.TableName()) <= 0 {
return ErrTableNotFound
}
var columnStr = session.Statement.ColumnStr
if len(session.Statement.selectStr) > 0 {
columnStr = session.Statement.selectStr
var columnStr = session.statement.ColumnStr
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {
if session.Statement.JoinStr == "" {
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.Statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else {
columnStr = session.Statement.genColumnStr()
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.Statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else {
columnStr = "*"
}
@ -119,13 +121,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
}
condSQL, condArgs, err := builder.ToSQL(session.Statement.cond.And(autoCond))
session.statement.cond = session.statement.cond.And(autoCond)
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return err
}
args = append(session.Statement.joinArgs, condArgs...)
sqlStr, err = session.Statement.genSelectSQL(columnStr, condSQL)
args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
if err != nil {
return err
}
@ -135,20 +138,20 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
args = append(args, args...)
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
}
if session.canCache() {
if cacher := session.Engine.getCacher2(table); cacher != nil &&
!session.Statement.IsDistinct &&
!session.Statement.unscoped {
if cacher := session.engine.getCacher2(table); cacher != nil &&
!session.statement.IsDistinct &&
!session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
if err != ErrCacheFailed {
return err
}
err = nil // !nashtsai! reset err to nil for ErrCacheFailed
session.Engine.logger.Warn("Cache Find Failed")
session.engine.logger.Warn("Cache Find Failed")
}
}
@ -156,21 +159,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
var rawRows *core.Rows
var err error
session.queryPreprocess(&sqlStr, args...)
if session.IsAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.Tx.Query(sqlStr, args...)
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return err
}
defer rawRows.Close()
defer rows.Close()
fields, err := rawRows.Columns()
fields, err := rows.Columns()
if err != nil {
return err
}
@ -240,24 +235,24 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields)
dataStruct := rValue(newValue.Interface())
tb, err := session.Engine.autoMapType(dataStruct)
tb, err := session.engine.autoMapType(dataStruct)
if err != nil {
return err
}
return session.rows2Beans(rawRows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
return session.rows2Beans(rows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
}
for rawRows.Next() {
for rows.Next() {
var newValue = newElemFunc(fields)
bean := newValue.Interface()
switch elemType.Kind() {
case reflect.Slice:
err = rawRows.ScanSlice(bean)
err = rows.ScanSlice(bean)
case reflect.Map:
err = rawRows.ScanMap(bean)
err = rows.ScanMap(bean)
default:
err = rawRows.Scan(bean)
err = rows.Scan(bean)
}
if err != nil {
@ -288,22 +283,22 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
newsql := session.Statement.convertIDSQL(sqlStr)
newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
tableName := session.Statement.TableName()
tableName := session.statement.TableName()
table := session.Statement.RefTable
cacher := session.Engine.getCacher2(table)
table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
rows, err := session.DB().Query(newsql, args...)
rows, err := session.NoCache().queryRows(newsql, args...)
if err != nil {
return err
}
@ -314,7 +309,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for rows.Next() {
i++
if i > 500 {
session.Engine.logger.Debug("[cacheFind] ids length > 500, no cache")
session.engine.logger.Debug("[cacheFind] ids length > 500, no cache")
return ErrCacheFailed
}
var res = make([]string, len(table.PrimaryKeys))
@ -324,7 +319,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
}
var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
pk[i], err = session.Engine.idTypeAssertion(col, res[i])
pk[i], err = session.engine.idTypeAssertion(col, res[i])
if err != nil {
return err
}
@ -333,13 +328,13 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ids = append(ids, pk)
}
session.Engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, newsql, args)
session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, newsql, args)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return err
}
} else {
session.Engine.logger.Debug("[cacheFind] cache hit sql:", newsql, args)
session.engine.logger.Debug("[cacheFind] cache hit sql:", newsql, args)
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
@ -358,16 +353,16 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ides = append(ides, id)
ididxes[sid] = idx
} else {
session.Engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean)
session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean)
pk := session.Engine.IdOf(bean)
pk := session.engine.IdOf(bean)
xid, err := pk.ToString()
if err != nil {
return err
}
if sid != xid {
session.Engine.logger.Error("[cacheFind] error cache", xid, sid, bean)
session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean)
return ErrCacheFailed
}
temps[idx] = bean
@ -375,9 +370,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
}
if len(ides) > 0 {
newSession := session.Engine.NewSession()
defer newSession.Close()
slices := reflect.New(reflect.SliceOf(t))
beans := slices.Interface()
@ -387,18 +379,18 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ff = append(ff, ie[0])
}
newSession.In("`"+table.PrimaryKeys[0]+"`", ff...)
session.In("`"+table.PrimaryKeys[0]+"`", ff...)
} else {
for _, ie := range ides {
cond := builder.NewCond()
for i, name := range table.PrimaryKeys {
cond = cond.And(builder.Eq{"`" + name + "`": ie[i]})
}
newSession.Or(cond)
session.Or(cond)
}
}
err = newSession.NoCache().Find(beans)
err = session.NoCache().find(beans)
if err != nil {
return err
}
@ -409,7 +401,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
id, err := session.Engine.idOfV(rv)
id, err := session.engine.idOfV(rv)
if err != nil {
return err
}
@ -420,7 +412,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
bean := rv.Interface()
temps[ididxes[sid]] = bean
session.Engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps)
session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps)
cacher.PutBean(tableName, sid, bean)
}
}
@ -428,7 +420,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for j := 0; j < len(temps); j++ {
bean := temps[j]
if bean == nil {
session.Engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps)
session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps)
// return errors.New("cache error") // !nashtsai! no need to return error, but continue instead
continue
}

View File

@ -15,18 +15,22 @@ import (
// Get retrieve one record from database, bean's non-empty fields
// will be as conditions
func (session *Session) Get(bean interface{}) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
return session.get(bean)
}
func (session *Session) get(bean interface{}) (bool, error) {
beanValue := reflect.ValueOf(bean)
if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer")
return false, errors.New("needs a pointer to a value")
} else if beanValue.Elem().Kind() == reflect.Ptr {
return false, errors.New("a pointer to a pointer is not allowed")
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := session.Statement.setRefValue(beanValue.Elem()); err != nil {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
return false, err
}
}
@ -35,23 +39,23 @@ func (session *Session) Get(bean interface{}) (bool, error) {
var args []interface{}
var err error
if session.Statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 {
if session.statement.RawSQL == "" {
if len(session.statement.TableName()) <= 0 {
return false, ErrTableNotFound
}
session.Statement.Limit(1)
sqlStr, args, err = session.Statement.genGetSQL(bean)
session.statement.Limit(1)
sqlStr, args, err = session.statement.genGetSQL(bean)
if err != nil {
return false, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
}
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil &&
!session.Statement.unscoped {
if cacher := session.engine.getCacher2(session.statement.RefTable); cacher != nil &&
!session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed {
return has, err
@ -63,50 +67,42 @@ func (session *Session) Get(bean interface{}) (bool, error) {
}
func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
session.queryPreprocess(&sqlStr, args...)
var rawRows *core.Rows
var err error
if session.IsAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.Tx.Query(sqlStr, args...)
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return false, err
}
defer rows.Close()
defer rawRows.Close()
if !rawRows.Next() {
if !rows.Next() {
return false, nil
}
switch beanKind {
case reflect.Struct:
fields, err := rawRows.Columns()
fields, err := rows.Columns()
if err != nil {
// WARN: Alougth rawRows return true, but get fields failed
// WARN: Alougth rows return true, but get fields failed
return true, err
}
dataStruct := rValue(bean)
if err := session.Statement.setRefValue(dataStruct); err != nil {
if err := session.statement.setRefValue(dataStruct); err != nil {
return false, err
}
scanResults, err := session.row2Slice(rawRows, fields, len(fields), bean)
scanResults, err := session.row2Slice(rows, fields, len(fields), bean)
if err != nil {
return false, err
}
rawRows.Close()
// close it before covert data
rows.Close()
_, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, session.Statement.RefTable)
_, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, session.statement.RefTable)
case reflect.Slice:
err = rawRows.ScanSlice(bean)
err = rows.ScanSlice(bean)
case reflect.Map:
err = rawRows.ScanMap(bean)
err = rows.ScanMap(bean)
default:
err = rawRows.Scan(bean)
err = rows.Scan(bean)
}
return true, err
@ -118,22 +114,22 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
newsql := session.Statement.convertIDSQL(sqlStr)
newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" {
return false, ErrCacheFailed
}
cacher := session.Engine.getCacher2(session.Statement.RefTable)
tableName := session.Statement.TableName()
session.Engine.logger.Debug("[cacheGet] find sql:", newsql, args)
cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.statement.TableName()
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
table := session.Statement.RefTable
table := session.statement.RefTable
if err != nil {
var res = make([]string, len(table.PrimaryKeys))
rows, err := session.DB().Query(newsql, args...)
rows, err := session.NoCache().queryRows(newsql, args...)
if err != nil {
return false, err
}
@ -164,19 +160,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
}
ids = []core.PK{pk}
session.Engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return false, err
}
} else {
session.Engine.logger.Debug("[cacheGet] cache hit sql:", newsql)
session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql)
}
if len(ids) > 0 {
structValue := reflect.Indirect(reflect.ValueOf(bean))
id := ids[0]
session.Engine.logger.Debug("[cacheGet] get bean:", tableName, id)
session.engine.logger.Debug("[cacheGet] get bean:", tableName, id)
sid, err := id.ToString()
if err != nil {
return false, err
@ -189,10 +185,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return has, err
}
session.Engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean)
session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean)
cacher.PutBean(tableName, sid, cacheBean)
} else {
session.Engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean)
session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean)
has = true
}
structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))

View File

@ -71,15 +71,18 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "28", valuesString["age"])
assert.Equal(t, "1.5", valuesString["money"])
var valuesInter = make(map[string]interface{})
has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 5, len(valuesInter))
assert.EqualValues(t, 1, valuesInter["id"])
assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"]))
assert.EqualValues(t, 28, valuesInter["age"])
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"]))
// for mymysql driver, interface{} will be []byte, so ignore it currently
if testEngine.dialect.DriverName() != "mymysql" {
var valuesInter = make(map[string]interface{})
has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 5, len(valuesInter))
assert.EqualValues(t, 1, valuesInter["id"])
assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"]))
assert.EqualValues(t, 28, valuesInter["age"])
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"]))
}
var valuesSliceString = make([]string, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceString)
@ -171,3 +174,23 @@ func TestGetSlice(t *testing.T) {
assert.False(t, has)
assert.Error(t, err)
}
func TestGetError(t *testing.T) {
assert.NoError(t, prepareEngine())
type GetError struct {
Uid int `xorm:"pk autoincr"`
IsMan bool
}
assertSync(t, new(GetError))
var info = new(GetError)
has, err := testEngine.Get(&info)
assert.False(t, has)
assert.Error(t, err)
has, err = testEngine.Get(info)
assert.False(t, has)
assert.NoError(t, err)
}

View File

@ -19,17 +19,16 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
var affected int64
var err error
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
defer session.resetStatement()
for _, bean := range beans {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice {
size := sliceValue.Len()
if size > 0 {
if session.Engine.SupportInsertMany() {
if session.engine.SupportInsertMany() {
cnt, err := session.innerInsertMulti(bean)
if err != nil {
return affected, err
@ -67,15 +66,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice")
}
if err := session.Statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
return 0, err
}
if len(session.Statement.TableName()) <= 0 {
if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound
}
table := session.Statement.RefTable
table := session.statement.RefTable
size := sliceValue.Len()
var colNames []string
@ -116,18 +115,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name)
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val)
var colName = col.Name
@ -135,7 +134,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -171,18 +170,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
if session.Statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name)
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val)
var colName = col.Name
@ -190,7 +189,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.Statement.checkVersion {
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -214,25 +213,25 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
var statement string
if session.Engine.dialect.DBType() == core.ORACLE {
if session.engine.dialect.DBType() == core.ORACLE {
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
session.Engine.QuoteStr())
session.engine.Quote(session.statement.TableName()),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr())
statement = fmt.Sprintf(sql,
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
session.Engine.QuoteStr(),
session.engine.Quote(session.statement.TableName()),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, temp))
} else {
statement = fmt.Sprintf(sql,
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
session.Engine.QuoteStr(),
session.engine.Quote(session.statement.TableName()),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(statement, args...)
@ -240,8 +239,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, err
}
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.statement.TableName())
}
lenAfterClosures := len(session.afterClosures)
@ -249,7 +248,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
// handle AfterInsertProcessor
if session.IsAutoCommit {
if session.isAutoCommit {
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
for _, closure := range session.afterClosures {
closure(elemValue)
@ -280,8 +279,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
// InsertMulti insert multiple records
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
@ -299,14 +297,14 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
}
func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.Statement.setRefValue(rValue(bean)); err != nil {
if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err
}
if len(session.Statement.TableName()) <= 0 {
if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound
}
table := session.Statement.RefTable
table := session.statement.RefTable
// handle BeforeInsertProcessor
for _, closure := range session.beforeClosures {
@ -318,12 +316,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
processor.BeforeInsert()
}
// --
colNames, args, err := genCols(session.Statement.RefTable, session, bean, false, false)
colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false)
if err != nil {
return 0, err
}
// insert expr columns, override if exists
exprColumns := session.Statement.getExpr()
exprColumns := session.statement.getExpr()
exprColVals := make([]string, 0, len(exprColumns))
for _, v := range exprColumns {
// remove the expr columns
@ -351,21 +349,21 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
var sqlStr string
if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
session.Engine.Quote(session.Statement.TableName()),
session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.Quote(", ")),
session.Engine.QuoteStr(),
session.engine.Quote(session.statement.TableName()),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.Quote(", ")),
session.engine.QuoteStr(),
colPlaces)
} else {
if session.Engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.Engine.Quote(session.Statement.TableName()))
if session.engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(session.statement.TableName()))
} else {
sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.Engine.Quote(session.Statement.TableName()))
sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(session.statement.TableName()))
}
}
handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.IsAutoCommit {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}
@ -394,22 +392,22 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself.
if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.query("select seq_atable.currval from dual", args...)
if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil {
return 0, err
}
handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.statement.TableName())
}
if table.Version != "" && session.Statement.checkVersion {
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1)
}
@ -427,7 +425,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -437,24 +435,24 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil
} else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
} else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != ""
sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement)
res, err := session.query(sqlStr, args...)
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
res, err := session.queryBytes(sqlStr, args...)
if err != nil {
return 0, err
}
handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.statement.TableName())
}
if table.Version != "" && session.Statement.checkVersion {
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1)
}
@ -472,7 +470,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -490,14 +488,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
defer handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.statement.TableName())
}
if table.Version != "" && session.Statement.checkVersion {
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1)
}
@ -515,7 +513,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -532,8 +530,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// The in parameter bean must a struct or a point to struct. The return
// parameter is inserted and error
func (session *Session) InsertOne(bean interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
@ -541,15 +538,15 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
}
func (session *Session) cacheInsert(tables ...string) error {
if session.Statement.RefTable == nil {
if session.statement.RefTable == nil {
return ErrCacheFailed
}
table := session.Statement.RefTable
cacher := session.Engine.getCacher2(table)
table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
for _, t := range tables {
session.Engine.logger.Debug("[cache] clear sql:", t)
session.engine.logger.Debug("[cache] clear sql:", t)
cacher.ClearIds(t)
}

View File

@ -19,6 +19,10 @@ func (session *Session) Rows(bean interface{}) (*Rows, error) {
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct
func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
if session.isAutoClose {
defer session.Close()
}
rows, err := session.Rows(bean)
if err != nil {
return err

View File

@ -127,7 +127,7 @@ func TestIntId(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(bean.Id).Delete(&IntId{})
cnt, err = testEngine.ID(bean.Id).Delete(&IntId{})
if err != nil {
t.Error(err)
panic(err)
@ -202,7 +202,7 @@ func TestInt16Id(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(bean.Id).Delete(&Int16Id{})
cnt, err = testEngine.ID(bean.Id).Delete(&Int16Id{})
if err != nil {
t.Error(err)
panic(err)
@ -277,7 +277,7 @@ func TestInt32Id(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(bean.Id).Delete(&Int32Id{})
cnt, err = testEngine.ID(bean.Id).Delete(&Int32Id{})
if err != nil {
t.Error(err)
panic(err)
@ -366,7 +366,7 @@ func TestUintId(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(bean.Id).Delete(&UintId{})
cnt, err = testEngine.ID(bean.Id).Delete(&UintId{})
if err != nil {
t.Error(err)
panic(err)
@ -441,7 +441,7 @@ func TestUint16Id(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(bean.Id).Delete(&Uint16Id{})
cnt, err = testEngine.ID(bean.Id).Delete(&Uint16Id{})
if err != nil {
t.Error(err)
panic(err)
@ -516,7 +516,7 @@ func TestUint32Id(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(bean.Id).Delete(&Uint32Id{})
cnt, err = testEngine.ID(bean.Id).Delete(&Uint32Id{})
if err != nil {
t.Error(err)
panic(err)
@ -604,7 +604,7 @@ func TestUint64Id(t *testing.T) {
panic(errors.New("should be equal"))
}
cnt, err = testEngine.Id(bean.Id).Delete(&Uint64Id{})
cnt, err = testEngine.ID(bean.Id).Delete(&Uint64Id{})
if err != nil {
t.Error(err)
panic(err)
@ -679,7 +679,7 @@ func TestStringPK(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(bean.Id).Delete(&StringPK{})
cnt, err = testEngine.ID(bean.Id).Delete(&StringPK{})
if err != nil {
t.Error(err)
panic(err)
@ -725,7 +725,7 @@ func TestCompositeKey(t *testing.T) {
}
var compositeKeyVal CompositeKey
has, err := testEngine.Id(core.PK{11, 22}).Get(&compositeKeyVal)
has, err := testEngine.ID(core.PK{11, 22}).Get(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if !has {
@ -734,7 +734,7 @@ func TestCompositeKey(t *testing.T) {
var compositeKeyVal2 CompositeKey
// test passing PK ptr, this test seem failed withCache
has, err = testEngine.Id(&core.PK{11, 22}).Get(&compositeKeyVal2)
has, err = testEngine.ID(&core.PK{11, 22}).Get(&compositeKeyVal2)
if err != nil {
t.Error(err)
} else if !has {
@ -781,14 +781,14 @@ func TestCompositeKey(t *testing.T) {
}
compositeKeyVal = CompositeKey{UpdateStr: "test1"}
cnt, err = testEngine.Id(core.PK{11, 22}).Update(&compositeKeyVal)
cnt, err = testEngine.ID(core.PK{11, 22}).Update(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update CompositeKey{11, 22}"))
}
cnt, err = testEngine.Id(core.PK{11, 22}).Delete(&CompositeKey{})
cnt, err = testEngine.ID(core.PK{11, 22}).Delete(&CompositeKey{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
@ -832,7 +832,7 @@ func TestCompositeKey2(t *testing.T) {
}
var user User
has, err := testEngine.Id(core.PK{"11", 22}).Get(&user)
has, err := testEngine.ID(core.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -840,7 +840,7 @@ func TestCompositeKey2(t *testing.T) {
}
// test passing PK ptr, this test seem failed withCache
has, err = testEngine.Id(&core.PK{"11", 22}).Get(&user)
has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -848,14 +848,14 @@ func TestCompositeKey2(t *testing.T) {
}
user = User{NickName: "test1"}
cnt, err = testEngine.Id(core.PK{"11", 22}).Update(&user)
cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}"))
}
cnt, err = testEngine.Id(core.PK{"11", 22}).Delete(&User{})
cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&User{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
@ -900,7 +900,7 @@ func TestCompositeKey3(t *testing.T) {
}
var user UserPK2
has, err := testEngine.Id(core.PK{"11", 22}).Get(&user)
has, err := testEngine.ID(core.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -908,7 +908,7 @@ func TestCompositeKey3(t *testing.T) {
}
// test passing PK ptr, this test seem failed withCache
has, err = testEngine.Id(&core.PK{"11", 22}).Get(&user)
has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -916,14 +916,14 @@ func TestCompositeKey3(t *testing.T) {
}
user = UserPK2{NickName: "test1"}
cnt, err = testEngine.Id(core.PK{"11", 22}).Update(&user)
cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}"))
}
cnt, err = testEngine.Id(core.PK{"11", 22}).Delete(&UserPK2{})
cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&UserPK2{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
@ -1007,7 +1007,7 @@ func TestMyIntId(t *testing.T) {
panic(errors.New("should be equal"))
}
cnt, err = testEngine.Id(bean.ID).Delete(&MyIntPK{})
cnt, err = testEngine.ID(bean.ID).Delete(&MyIntPK{})
if err != nil {
t.Error(err)
panic(err)
@ -1095,7 +1095,7 @@ func TestMyStringId(t *testing.T) {
panic(errors.New("should be equal"))
}
cnt, err = testEngine.Id(bean.ID).Delete(&MyStringPK{})
cnt, err = testEngine.ID(bean.ID).Delete(&MyStringPK{})
if err != nil {
t.Error(err)
panic(err)

View File

@ -6,77 +6,77 @@ package xorm
import (
"database/sql"
"fmt"
"reflect"
"strconv"
"time"
"github.com/go-xorm/core"
)
func (session *Session) query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) {
session.queryPreprocess(&sqlStr, paramStr...)
if session.IsAutoCommit {
return session.innerQuery2(sqlStr, paramStr...)
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
}
return session.txQuery(session.Tx, sqlStr, paramStr...)
session.lastSQL = *sqlStr
session.lastSQLArgs = paramStr
}
func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string][]byte, error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) {
defer session.resetStatement()
session.queryPreprocess(&sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
defer rows.Close()
return rows2maps(rows)
}
func (session *Session) innerQuery(sqlStr string, params ...interface{}) (*core.Stmt, *core.Rows, error) {
var callback func() (*core.Stmt, *core.Rows, error)
if session.prepareStmt {
callback = func() (*core.Stmt, *core.Rows, error) {
if session.isAutoCommit {
if session.prepareStmt {
// don't clear stmt since session will cache them
stmt, err := session.doPrepare(sqlStr)
if err != nil {
return nil, nil, err
return nil, err
}
rows, err := stmt.Query(params...)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
} else {
callback = func() (*core.Stmt, *core.Rows, error) {
rows, err := session.DB().Query(sqlStr, params...)
if err != nil {
return nil, nil, err
}
return nil, rows, err
}
}
stmt, rows, err := session.Engine.logSQLQueryTime(sqlStr, params, callback)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2map(rows, fields)
rows, err := stmt.Query(args...)
if err != nil {
return nil, err
}
return rows, nil
}
rows, err := session.DB().Query(sqlStr, args...)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result)
return rows, nil
}
return resultsSlice, nil
rows, err := session.tx.Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
return core.NewRow(session.queryRows(sqlStr, args...))
}
func value2Bytes(rawValue *reflect.Value) (data []byte, err error) {
@ -104,7 +104,7 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
result[key] = []byte{}
continue
}
@ -117,34 +117,13 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er
return result, nil
}
func (session *Session) innerQuery2(sqlStr string, params ...interface{}) ([]map[string][]byte, error) {
_, rows, err := session.innerQuery(sqlStr, params...)
if rows != nil {
defer rows.Close()
}
if err != nil {
return nil, err
}
return rows2maps(rows)
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
return session.query(sqlStr, paramStr...)
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2mapStr(rows, fields)
result, err := row2map(rows, fields)
if err != nil {
return nil, err
}
@ -154,122 +133,45 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error)
return resultsSlice, nil
}
func reflect2value(rawValue *reflect.Value) (str string, err error) {
aa := reflect.TypeOf((*rawValue).Interface())
vv := reflect.ValueOf((*rawValue).Interface())
switch aa.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
case reflect.String:
str = vv.String()
case reflect.Array, reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
data := rawValue.Interface().([]byte)
str = string(data)
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(core.TimeType) {
str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
case reflect.Bool:
str = strconv.FormatBool(vv.Bool())
case reflect.Complex128, reflect.Complex64:
str = fmt.Sprintf("%v", vv.Complex())
/* TODO: unsupported types below
case reflect.Map:
case reflect.Ptr:
case reflect.Uintptr:
case reflect.UnsafePointer:
case reflect.Chan, reflect.Func, reflect.Interface:
*/
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
return
}
func value2String(rawValue *reflect.Value) (data string, err error) {
data, err = reflect2value(rawValue)
if err != nil {
return
}
return
}
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
result := make(map[string]string)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2String(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string]string, error) {
rows, err := tx.Query(sqlStr, params...)
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
return rows2maps(rows)
}
func query2(db *core.DB, sqlStr string, params ...interface{}) ([]map[string]string, error) {
rows, err := db.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
// QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
session.queryPreprocess(&sqlStr, args...)
if session.IsAutoCommit {
return query2(session.DB(), sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
if !session.isAutoCommit {
return session.tx.Exec(sqlStr, args...)
}
return txQuery2(session.Tx, sqlStr, args...)
}
// Execute sql
func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) {
if session.prepareStmt {
stmt, err := session.doPrepare(sqlStr)
if err != nil {
@ -286,33 +188,9 @@ func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Resul
return session.DB().Exec(sqlStr, args...)
}
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
for _, filter := range session.Engine.dialect.Filters() {
// TODO: for table name, it's no need to RefTable
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
session.saveLastSQL(sqlStr, args...)
return session.Engine.logSQLExecutionTime(sqlStr, args, func() (sql.Result, error) {
if session.IsAutoCommit {
// FIXME: oci8 can not auto commit (github.com/mattn/go-oci8)
if session.Engine.dialect.DBType() == core.ORACLE {
session.Begin()
r, err := session.Tx.Exec(sqlStr, args...)
session.Commit()
return r, err
}
return session.innerExec(sqlStr, args...)
}
return session.Tx.Exec(sqlStr, args...)
})
}
// Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}

View File

@ -7,42 +7,10 @@ package xorm
import (
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestQueryString(t *testing.T) {
assert.NoError(t, prepareEngine())
type GetVar struct {
Id int64 `xorm:"autoincr pk"`
Msg string `xorm:"varchar(255)"`
Age int
Money float32
Created time.Time `xorm:"created"`
}
assert.NoError(t, testEngine.Sync2(new(GetVar)))
var data = GetVar{
Msg: "hi",
Age: 28,
Money: 1.5,
}
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)
records, err := testEngine.QueryString("select * from get_var")
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
assert.Equal(t, "1", records[0]["id"])
assert.Equal(t, "hi", records[0]["msg"])
assert.Equal(t, "28", records[0]["age"])
assert.Equal(t, "1.5", records[0]["money"])
}
func TestQuery(t *testing.T) {
assert.NoError(t, prepareEngine())

View File

@ -16,42 +16,50 @@ import (
// Ping test if database is ok
func (session *Session) Ping() error {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().Ping()
}
// CreateTable create a table according a bean
func (session *Session) CreateTable(bean interface{}) error {
v := rValue(bean)
if err := session.Statement.setRefValue(v); err != nil {
return err
}
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
return session.createOneTable()
return session.createTable(bean)
}
func (session *Session) createTable(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqlStr := session.statement.genCreateTableSQL()
_, err := session.exec(sqlStr)
return err
}
// CreateIndexes create indexes
func (session *Session) CreateIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.Statement.setRefValue(v); err != nil {
return err
}
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
sqls := session.Statement.genIndexSQL()
return session.createIndexes(bean)
}
func (session *Session) createIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqls := session.statement.genIndexSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -63,17 +71,19 @@ func (session *Session) CreateIndexes(bean interface{}) error {
// CreateUniques create uniques
func (session *Session) CreateUniques(bean interface{}) error {
if session.isAutoClose {
defer session.Close()
}
return session.createUniques(bean)
}
func (session *Session) createUniques(bean interface{}) error {
v := rValue(bean)
if err := session.Statement.setRefValue(v); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err
}
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
sqls := session.Statement.genUniqueSQL()
sqls := session.statement.genUniqueSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -83,25 +93,22 @@ func (session *Session) CreateUniques(bean interface{}) error {
return nil
}
func (session *Session) createOneTable() error {
sqlStr := session.Statement.genCreateTableSQL()
_, err := session.exec(sqlStr)
return err
}
// DropIndexes drop indexes
func (session *Session) DropIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.Statement.setRefValue(v); err != nil {
return err
}
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
sqls := session.Statement.genDelIndexSQL()
return session.dropIndexes(bean)
}
func (session *Session) dropIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqls := session.statement.genDelIndexSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -113,15 +120,23 @@ func (session *Session) DropIndexes(bean interface{}) error {
// DropTable drop table will drop table if exist, if drop failed, it will return error
func (session *Session) DropTable(beanOrTableName interface{}) error {
tableName, err := session.Engine.tableName(beanOrTableName)
if session.isAutoClose {
defer session.Close()
}
return session.dropTable(beanOrTableName)
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return err
}
var needDrop = true
if !session.Engine.dialect.SupportDropIfExists() {
sqlStr, args := session.Engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...)
if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.queryBytes(sqlStr, args...)
if err != nil {
return err
}
@ -129,7 +144,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}
if needDrop {
sqlStr := session.Engine.Dialect().DropTableSql(tableName)
sqlStr := session.engine.Dialect().DropTableSql(tableName)
_, err = session.exec(sqlStr)
return err
}
@ -138,7 +153,11 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
// IsTableExist if a table is exist
func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
tableName, err := session.Engine.tableName(beanOrTableName)
if session.isAutoClose {
defer session.Close()
}
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return false, err
}
@ -147,12 +166,8 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
}
func (session *Session) isTableExist(tableName string) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
sqlStr, args := session.Engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...)
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err
}
@ -162,6 +177,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
t := v.Type()
if t.Kind() == reflect.String {
if session.isAutoClose {
defer session.Close()
}
return session.isTableEmpty(bean.(string))
} else if t.Kind() == reflect.Struct {
rows, err := session.Count(bean)
@ -171,15 +189,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
}
func (session *Session) isTableEmpty(tableName string) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName))
err := session.DB().QueryRow(sqlStr).Scan(&total)
session.saveLastSQL(sqlStr)
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
err := session.queryRow(sqlStr).Scan(&total)
if err != nil {
if err == sql.ErrNoRows {
err = nil
@ -192,12 +204,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
// find if index is exist according cols
func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
indexes, err := session.Engine.dialect.GetIndexes(tableName)
indexes, err := session.engine.dialect.GetIndexes(tableName)
if err != nil {
return false, err
}
@ -214,43 +221,34 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
}
func (session *Session) addColumn(colName string) error {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
col := session.Statement.RefTable.GetColumn(colName)
sql, args := session.Statement.genAddColumnStr(col)
col := session.statement.RefTable.GetColumn(colName)
sql, args := session.statement.genAddColumnStr(col)
_, err := session.exec(sql, args...)
return err
}
func (session *Session) addIndex(tableName, idxName string) error {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
index := session.Statement.RefTable.Indexes[idxName]
sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index)
index := session.statement.RefTable.Indexes[idxName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
_, err := session.exec(sqlStr)
return err
}
func (session *Session) addUnique(tableName, uqeName string) error {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
index := session.Statement.RefTable.Indexes[uqeName]
sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index)
index := session.statement.RefTable.Indexes[uqeName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
_, err := session.exec(sqlStr)
return err
}
// Sync2 synchronize structs to database tables
func (session *Session) Sync2(beans ...interface{}) error {
engine := session.Engine
engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.DBMetas()
if err != nil {
@ -277,17 +275,17 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
if oriTable == nil {
err = session.StoreEngine(session.Statement.StoreEngine).CreateTable(bean)
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil {
return err
}
err = session.CreateUniques(bean)
err = session.createUniques(bean)
if err != nil {
return err
}
err = session.CreateIndexes(bean)
err = session.createIndexes(bean)
if err != nil {
return err
}
@ -312,7 +310,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType)
@ -322,7 +320,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
}
}
} else {
@ -336,7 +334,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
}
}
}
@ -349,10 +347,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbName, col.Name, oriCol.Nullable, col.Nullable)
}
} else {
session := engine.NewSession()
session.Statement.RefTable = table
session.Statement.tableName = tbName
defer session.Close()
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addColumn(col.Name)
}
if err != nil {
@ -376,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex)
_, err = engine.Exec(sql)
_, err = session.exec(sql)
if err != nil {
return err
}
@ -392,7 +388,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2)
_, err = engine.Exec(sql)
_, err = session.exec(sql)
if err != nil {
return err
}
@ -401,16 +397,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames {
if index.Type == core.UniqueType {
session := engine.NewSession()
session.Statement.RefTable = table
session.Statement.tableName = tbName
defer session.Close()
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addUnique(tbName, name)
} else if index.Type == core.IndexType {
session := engine.NewSession()
session.Statement.RefTable = table
session.Statement.tableName = tbName
defer session.Close()
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addIndex(tbName, name)
}
if err != nil {

View File

@ -1,187 +0,0 @@
// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import "database/sql"
// Count counts the records. bean's non-empty fields
// are conditions.
func (session *Session) Count(bean ...interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if session.Statement.RawSQL == "" {
if len(bean) == 0 {
return 0, ErrTableNotFound
}
sqlStr, args, err = session.Statement.genCountSQL(bean[0])
if err != nil {
return 0, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var total int64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&total)
}
if err == sql.ErrNoRows || err == nil {
return total, nil
}
return 0, err
}
// Sum call sum some column. bean's non-empty fields are conditions.
func (session *Session) Sum(bean interface{}, columnName string) (float64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
if err != nil {
return 0, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res float64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return 0, err
}
// SumInt call sum some column. bean's non-empty fields are conditions.
func (session *Session) SumInt(bean interface{}, columnName string) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
if err != nil {
return 0, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res int64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return 0, err
}
// Sums call sum some columns. bean's non-empty fields are conditions.
func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
if err != nil {
return nil, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res = make([]float64, len(columnNames), len(columnNames))
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return nil, err
}
// SumsInt sum specify columns and return as []int64 instead of []float64
func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
if err != nil {
return nil, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res = make([]int64, len(columnNames), len(columnNames))
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return nil, err
}

View File

@ -1,152 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"strconv"
"testing"
"github.com/go-xorm/builder"
"github.com/stretchr/testify/assert"
)
func isFloatEq(i, j float64, precision int) bool {
return fmt.Sprintf("%."+strconv.Itoa(precision)+"f", i) == fmt.Sprintf("%."+strconv.Itoa(precision)+"f", j)
}
func TestSum(t *testing.T) {
assert.NoError(t, prepareEngine())
type SumStruct struct {
Int int
Float float32
}
var (
cases = []SumStruct{
{1, 6.2},
{2, 5.3},
{92, -0.2},
}
)
var i int
var f float32
for _, v := range cases {
i += v.Int
f += v.Float
}
assert.NoError(t, testEngine.Sync2(new(SumStruct)))
cnt, err := testEngine.Insert(cases)
assert.NoError(t, err)
assert.EqualValues(t, 3, cnt)
colInt := testEngine.ColumnMapper.Obj2Table("Int")
colFloat := testEngine.ColumnMapper.Obj2Table("Float")
sumInt, err := testEngine.Sum(new(SumStruct), colInt)
assert.NoError(t, err)
assert.EqualValues(t, int(sumInt), i)
sumFloat, err := testEngine.Sum(new(SumStruct), colFloat)
assert.NoError(t, err)
assert.Condition(t, func() bool {
return isFloatEq(sumFloat, float64(f), 2)
})
sums, err := testEngine.Sums(new(SumStruct), colInt, colFloat)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(sums))
assert.EqualValues(t, i, int(sums[0]))
assert.Condition(t, func() bool {
return isFloatEq(sums[1], float64(f), 2)
})
sumsInt, err := testEngine.SumsInt(new(SumStruct), colInt)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(sumsInt))
assert.EqualValues(t, i, int(sumsInt[0]))
}
func TestSumCustomColumn(t *testing.T) {
assert.NoError(t, prepareEngine())
type SumStruct struct {
Int int
Float float32
}
var (
cases = []SumStruct{
{1, 6.2},
{2, 5.3},
{92, -0.2},
}
)
assert.NoError(t, testEngine.Sync2(new(SumStruct)))
cnt, err := testEngine.Insert(cases)
assert.NoError(t, err)
assert.EqualValues(t, 3, cnt)
sumInt, err := testEngine.Sum(new(SumStruct),
"CASE WHEN `int` <= 2 THEN `int` ELSE 0 END")
assert.NoError(t, err)
assert.EqualValues(t, 3, int(sumInt))
}
func TestCount(t *testing.T) {
assert.NoError(t, prepareEngine())
type UserinfoCount struct {
Departname string
}
assert.NoError(t, testEngine.Sync2(new(UserinfoCount)))
colName := testEngine.ColumnMapper.Obj2Table("Departname")
var cond builder.Cond = builder.Eq{
"`" + colName + "`": "dev",
}
total, err := testEngine.Where(cond).Count(new(UserinfoCount))
assert.NoError(t, err)
assert.EqualValues(t, 0, total)
cnt, err := testEngine.Insert(&UserinfoCount{
Departname: "dev",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
total, err = testEngine.Where(cond).Count(new(UserinfoCount))
assert.NoError(t, err)
assert.EqualValues(t, 1, total)
}
func TestSQLCount(t *testing.T) {
assert.NoError(t, prepareEngine())
type UserinfoCount2 struct {
Id int64
Departname string
}
type UserinfoBooks struct {
Id int64
Pid int64
IsOpen bool
}
assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2").
Count()
assert.NoError(t, err)
assert.EqualValues(t, 0, total)
}

View File

@ -6,14 +6,14 @@ package xorm
// Begin a transaction
func (session *Session) Begin() error {
if session.IsAutoCommit {
if session.isAutoCommit {
tx, err := session.DB().Begin()
if err != nil {
return err
}
session.IsAutoCommit = false
session.IsCommitedOrRollbacked = false
session.Tx = tx
session.isAutoCommit = false
session.isCommitedOrRollbacked = false
session.tx = tx
session.saveLastSQL("BEGIN TRANSACTION")
}
return nil
@ -21,25 +21,23 @@ func (session *Session) Begin() error {
// Rollback When using transaction, you can rollback if any error
func (session *Session) Rollback() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
session.saveLastSQL(session.Engine.dialect.RollBackStr())
session.IsCommitedOrRollbacked = true
return session.Tx.Rollback()
if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL(session.engine.dialect.RollBackStr())
session.isCommitedOrRollbacked = true
return session.tx.Rollback()
}
return nil
}
// Commit When using transaction, Commit will commit all operations.
func (session *Session) Commit() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL("COMMIT")
session.IsCommitedOrRollbacked = true
session.isCommitedOrRollbacked = true
var err error
if err = session.Tx.Commit(); err == nil {
if err = session.tx.Commit(); err == nil {
// handle processors after tx committed
closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) {
if closuresPtr != nil {
for _, closure := range *closuresPtr {
closure(bean)

View File

@ -16,19 +16,19 @@ import (
)
func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil ||
session.Tx != nil {
if session.statement.RefTable == nil ||
session.tx != nil {
return ErrCacheFailed
}
oldhead, newsql := session.Statement.convertUpdateSQL(sqlStr)
oldhead, newsql := session.statement.convertUpdateSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
for _, filter := range session.Engine.dialect.Filters() {
newsql = filter.Do(newsql, session.Engine.dialect, session.Statement.RefTable)
for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(newsql, session.engine.dialect, session.statement.RefTable)
}
session.Engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
var nStart int
if len(args) > 0 {
@ -39,13 +39,13 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
nStart = strings.Count(oldhead, "$")
}
}
table := session.Statement.RefTable
cacher := session.Engine.getCacher2(table)
tableName := session.Statement.TableName()
session.Engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
tableName := session.statement.TableName()
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil {
rows, err := session.DB().Query(newsql, args[nStart:]...)
rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
if err != nil {
return err
}
@ -75,9 +75,9 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
ids = append(ids, pk)
}
session.Engine.logger.Debug("[cacheUpdate] find updated id", ids)
session.engine.logger.Debug("[cacheUpdate] find updated id", ids)
} /*else {
session.Engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
cacher.DelIds(tableName, genSqlKey(newsql, args))
}*/
@ -103,36 +103,36 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") {
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
} else if strings.Contains(colName, session.Engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1))
} else if strings.Contains(colName, session.engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
} else {
session.Engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
}
if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean)
if err != nil {
session.Engine.logger.Error(err)
session.engine.logger.Error(err)
} else {
session.Engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.Statement.checkVersion {
session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.statement.checkVersion {
fieldValue.SetInt(fieldValue.Int() + 1)
} else {
fieldValue.Set(reflect.ValueOf(args[idx]))
}
}
} else {
session.Engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
colName, table.Name)
}
}
session.Engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
cacher.PutBean(tableName, sid, bean)
}
}
session.Engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
cacher.ClearIds(tableName)
return nil
}
@ -144,8 +144,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
// You should call UseBool if you have bool to use.
// 2.float32 & float64 may be not inexact as conditions
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
if session.isAutoClose {
defer session.Close()
}
@ -169,21 +168,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct
if isStruct {
if err := session.Statement.setRefValue(v); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return 0, err
}
if len(session.Statement.TableName()) <= 0 {
if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound
}
if session.Statement.ColumnStr == "" {
colNames, args = buildUpdates(session.Engine, session.Statement.RefTable, bean, false, false,
false, false, session.Statement.allUseBool, session.Statement.useAllCols,
session.Statement.mustColumnMap, session.Statement.nullableMap,
session.Statement.columnMap, true, session.Statement.unscoped)
if session.statement.ColumnStr == "" {
colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false,
false, false, session.statement.allUseBool, session.statement.useAllCols,
session.statement.mustColumnMap, session.statement.nullableMap,
session.statement.columnMap, true, session.statement.unscoped)
} else {
colNames, args, err = genCols(session.Statement.RefTable, session, bean, true, true)
colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true)
if err != nil {
return 0, err
}
@ -194,19 +193,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
bValue := reflect.Indirect(reflect.ValueOf(bean))
for _, v := range bValue.MapKeys() {
colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
colNames = append(colNames, session.engine.Quote(v.String())+" = ?")
args = append(args, bValue.MapIndex(v).Interface())
}
} else {
return 0, ErrParamsType
}
table := session.Statement.RefTable
table := session.statement.RefTable
if session.Statement.UseAutoTime && table != nil && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
if session.statement.UseAutoTime && table != nil && table.Updated != "" {
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn()
val, t := session.Engine.NowTime2(col.SQLType.Name)
val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val)
var colName = col.Name
@ -219,45 +218,44 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
//for update action to like "column = column + ?"
incColumns := session.Statement.getInc()
incColumns := session.statement.getInc()
for _, v := range incColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" + ?")
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
args = append(args, v.arg)
}
//for update action to like "column = column - ?"
decColumns := session.Statement.getDec()
decColumns := session.statement.getDec()
for _, v := range decColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" - ?")
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
args = append(args, v.arg)
}
//for update action to like "column = expression"
exprColumns := session.Statement.getExpr()
exprColumns := session.statement.getExpr()
for _, v := range exprColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
}
if err = session.Statement.processIDParam(); err != nil {
if err = session.statement.processIDParam(); err != nil {
return 0, err
}
var autoCond builder.Cond
if !session.Statement.noAutoCondition && len(condiBean) > 0 {
if !session.statement.noAutoCondition && len(condiBean) > 0 {
var err error
autoCond, err = session.Statement.buildConds(session.Statement.RefTable, condiBean[0], true, true, false, true, false)
autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}
}
st := session.Statement
defer session.resetStatement()
st := &session.statement
var sqlStr string
var condArgs []interface{}
var condSQL string
cond := session.Statement.cond.And(autoCond)
cond := session.statement.cond.And(autoCond)
var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion)
var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion)
var verValue *reflect.Value
if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean)
@ -265,8 +263,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err
}
cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
}
condSQL, condArgs, err = builder.ToSQL(cond)
@ -290,7 +288,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} else if st.Engine.dialect.DBType() == core.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
session.engine.Quote(session.statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
@ -301,7 +299,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} else if st.Engine.dialect.DBType() == core.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
session.engine.Quote(session.statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
@ -315,7 +313,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...)
session.engine.Quote(session.statement.TableName()), condSQL), condArgs...)
condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
@ -330,9 +328,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
if len(colNames) <= 0 {
return 0, errors.New("No content found to be updated")
}
sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v",
top,
session.Engine.Quote(session.Statement.TableName()),
session.engine.Quote(session.statement.TableName()),
strings.Join(colNames, ", "),
condSQL)
@ -346,19 +348,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
if table != nil {
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
cacher.ClearIds(session.Statement.TableName())
cacher.ClearBeans(session.Statement.TableName())
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
cacher.ClearIds(session.statement.TableName())
cacher.ClearBeans(session.statement.TableName())
}
}
// handle after update processors
if session.IsAutoCommit {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}
if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
session.Engine.logger.Debug("[event]", session.Statement.TableName(), " has after update processor")
session.engine.logger.Debug("[event]", session.statement.TableName(), " has after update processor")
processor.AfterUpdate()
}
} else {

View File

@ -298,7 +298,7 @@ func TestUpdate1(t *testing.T) {
// update by id
user := Userinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.Id(ori.Uid).Update(&user)
cnt, err := testEngine.ID(ori.Uid).Update(&user)
if err != nil {
t.Error(err)
panic(err)
@ -311,7 +311,7 @@ func TestUpdate1(t *testing.T) {
}
condi := Condi{"username": "zzz", "departname": ""}
cnt, err = testEngine.Table(&user).Id(ori.Uid).Update(&condi)
cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
if err != nil {
t.Error(err)
panic(err)
@ -351,7 +351,7 @@ func TestUpdate1(t *testing.T) {
}
userID := user.Uid
has, err := testEngine.Id(userID).
has, err := testEngine.ID(userID).
And("username = ?", user.Username).
And("height = ?", user.Height).
And("departname = ?", "").
@ -369,7 +369,7 @@ func TestUpdate1(t *testing.T) {
}
updatedUser := &Userinfo{Username: "null data"}
cnt, err = testEngine.Id(userID).
cnt, err = testEngine.ID(userID).
Nullable("height", "departname", "is_man", "created").
Update(updatedUser)
if err != nil {
@ -382,7 +382,7 @@ func TestUpdate1(t *testing.T) {
panic(err)
}
has, err = testEngine.Id(userID).
has, err = testEngine.ID(userID).
And("username = ?", updatedUser.Username).
And("height IS NULL").
And("departname IS NULL").
@ -400,7 +400,7 @@ func TestUpdate1(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(userID).Delete(&Userinfo{})
cnt, err = testEngine.ID(userID).Delete(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
@ -445,7 +445,7 @@ func TestUpdate1(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(a.Id).Update(&Article{Name: "6"})
cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
if err != nil {
t.Error(err)
panic(err)
@ -474,14 +474,14 @@ func TestUpdate1(t *testing.T) {
}
col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.Id(col2.Id).AllCols().Update(col2)
_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
col3 := &UpdateAllCols{}
has, err = testEngine.Id(col2.Id).Get(col3)
has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
@ -519,14 +519,14 @@ func TestUpdate1(t *testing.T) {
col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.ColumnMapper.Obj2Table("Bool")
stringStr := testEngine.ColumnMapper.Obj2Table("String")
_, err = testEngine.Id(col2.Id).MustCols(boolStr, stringStr).Update(col2)
_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
col3 := &UpdateMustCols{}
has, err := testEngine.Id(col2.Id).Get(col3)
has, err := testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
@ -561,17 +561,27 @@ func TestUpdateIncrDecr(t *testing.T) {
colName := testEngine.ColumnMapper.Obj2Table("Cnt")
cnt, err := testEngine.Id(col1.Id).Incr(colName).Update(col1)
cnt, err := testEngine.ID(col1.Id).Incr(colName).Update(col1)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
newCol := new(UpdateIncr)
has, err := testEngine.Id(col1.Id).Get(newCol)
has, err := testEngine.ID(col1.Id).Get(newCol)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, newCol.Cnt)
cnt, err = testEngine.Id(col1.Id).Cols(colName).Incr(colName).Update(col1)
cnt, err = testEngine.ID(col1.Id).Decr(colName).Update(col1)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
newCol = new(UpdateIncr)
has, err = testEngine.ID(col1.Id).Get(newCol)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, newCol.Cnt)
cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}
@ -616,12 +626,12 @@ func TestUpdateUpdated(t *testing.T) {
}
ci := &UpdatedUpdate{}
_, err = testEngine.Id(1).Update(ci)
_, err = testEngine.ID(1).Update(ci)
if err != nil {
t.Fatal(err)
}
has, err := testEngine.Id(1).Get(di)
has, err := testEngine.ID(1).Get(di)
if err != nil {
t.Fatal(err)
}
@ -644,11 +654,11 @@ func TestUpdateUpdated(t *testing.T) {
t.Fatal(err)
}
ci2 := &UpdatedUpdate2{}
_, err = testEngine.Id(1).Update(ci2)
_, err = testEngine.ID(1).Update(ci2)
if err != nil {
t.Fatal(err)
}
has, err = testEngine.Id(1).Get(di2)
has, err = testEngine.ID(1).Get(di2)
if err != nil {
t.Fatal(err)
}
@ -671,12 +681,12 @@ func TestUpdateUpdated(t *testing.T) {
t.Fatal(err)
}
ci3 := &UpdatedUpdate3{}
_, err = testEngine.Id(1).Update(ci3)
_, err = testEngine.ID(1).Update(ci3)
if err != nil {
t.Fatal(err)
}
has, err = testEngine.Id(1).Get(di3)
has, err = testEngine.ID(1).Get(di3)
if err != nil {
t.Fatal(err)
}
@ -700,12 +710,12 @@ func TestUpdateUpdated(t *testing.T) {
}
ci4 := &UpdatedUpdate4{}
_, err = testEngine.Id(1).Update(ci4)
_, err = testEngine.ID(1).Update(ci4)
if err != nil {
t.Fatal(err)
}
has, err = testEngine.Id(1).Get(di4)
has, err = testEngine.ID(1).Get(di4)
if err != nil {
t.Fatal(err)
}
@ -728,12 +738,12 @@ func TestUpdateUpdated(t *testing.T) {
t.Fatal(err)
}
ci5 := &UpdatedUpdate5{}
_, err = testEngine.Id(1).Update(ci5)
_, err = testEngine.ID(1).Update(ci5)
if err != nil {
t.Fatal(err)
}
has, err = testEngine.Id(1).Get(di5)
has, err = testEngine.ID(1).Get(di5)
if err != nil {
t.Fatal(err)
}
@ -786,7 +796,7 @@ func TestUpdateSameMapper(t *testing.T) {
}
// update by id
user := Userinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.Id(ori.Uid).Update(&user)
cnt, err := testEngine.ID(ori.Uid).Update(&user)
if err != nil {
t.Error(err)
panic(err)
@ -799,7 +809,7 @@ func TestUpdateSameMapper(t *testing.T) {
}
condi := Condi{"Username": "zzz", "Departname": ""}
cnt, err = testEngine.Table(&user).Id(ori.Uid).Update(&condi)
cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
if err != nil {
t.Error(err)
panic(err)
@ -864,7 +874,7 @@ func TestUpdateSameMapper(t *testing.T) {
panic(err)
}
cnt, err = testEngine.Id(a.Id).Update(&Article{Name: "6"})
cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
if err != nil {
t.Error(err)
panic(err)
@ -891,14 +901,14 @@ func TestUpdateSameMapper(t *testing.T) {
}
col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.Id(col2.Id).AllCols().Update(col2)
_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
col3 := &UpdateAllCols{}
has, err = testEngine.Id(col2.Id).Get(col3)
has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
@ -935,14 +945,14 @@ func TestUpdateSameMapper(t *testing.T) {
col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.ColumnMapper.Obj2Table("Bool")
stringStr := testEngine.ColumnMapper.Obj2Table("String")
_, err = testEngine.Id(col2.Id).MustCols(boolStr, stringStr).Update(col2)
_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
col3 := &UpdateMustCols{}
has, err := testEngine.Id(col2.Id).Get(col3)
has, err := testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
@ -978,7 +988,7 @@ func TestUpdateSameMapper(t *testing.T) {
panic(err)
}
cnt, err := testEngine.Id(col1.Id).Incr("`Cnt`").Update(col1)
cnt, err := testEngine.ID(col1.Id).Incr("`Cnt`").Update(col1)
if err != nil {
t.Error(err)
panic(err)
@ -990,7 +1000,7 @@ func TestUpdateSameMapper(t *testing.T) {
}
newCol := new(UpdateIncr)
has, err := testEngine.Id(col1.Id).Get(newCol)
has, err := testEngine.ID(col1.Id).Get(newCol)
if err != nil {
t.Error(err)
panic(err)
@ -1093,3 +1103,53 @@ func TestBool(t *testing.T) {
}
}
}
func TestNoUpdate(t *testing.T) {
assert.NoError(t, prepareEngine())
type NoUpdate struct {
Id int64
Content string
}
assertSync(t, new(NoUpdate))
_, err := testEngine.Insert(&NoUpdate{
Content: "test",
})
assert.NoError(t, err)
_, err = testEngine.ID(1).Update(&NoUpdate{})
assert.Error(t, err)
assert.EqualValues(t, "No content found to be updated", err.Error())
}
func TestNewUpdate(t *testing.T) {
assert.NoError(t, prepareEngine())
type TbUserInfo struct {
Id int64 `xorm:"pk autoincr unique BIGINT" json:"id"`
Phone string `xorm:"not null unique VARCHAR(20)" json:"phone"`
UserName string `xorm:"VARCHAR(20)" json:"user_name"`
Gender int `xorm:"default 0 INTEGER" json:"gender"`
Pw string `xorm:"VARCHAR(100)" json:"pw"`
Token string `xorm:"TEXT" json:"token"`
Avatar string `xorm:"TEXT" json:"avatar"`
Extras interface{} `xorm:"JSON" json:"extras"`
Created time.Time `xorm:"DATETIME created"`
Updated time.Time `xorm:"DATETIME updated"`
Deleted time.Time `xorm:"DATETIME deleted"`
}
assertSync(t, new(TbUserInfo))
targetUsr := TbUserInfo{Phone: "13126564922"}
changeUsr := TbUserInfo{Token: "ABCDEFG"}
af, err := testEngine.Update(&changeUsr, &targetUsr)
assert.NoError(t, err)
assert.EqualValues(t, 0, af)
af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", 13126564922).Update(&changeUsr)
assert.NoError(t, err)
assert.EqualValues(t, 0, af)
}

View File

@ -272,6 +272,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface())
if fieldType == nil {
continue
}
requiredField := useAllCols
includeNil := useAllCols
@ -592,6 +595,22 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
return newColumns
}
func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, 0, len(statement.columnMap))
for col := range statement.columnMap {
fields := strings.Split(strings.TrimSpace(col), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
return newColumns
}
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
@ -618,7 +637,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
statement.columnMap[strings.ToLower(nc)] = true
}
newColumns := statement.col2NewColsWithQuote(columns...)
newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
return statement
@ -890,17 +909,24 @@ func (statement *Statement) buildConds(table *core.Table, bean interface{}, incl
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
}
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
func (statement *Statement) mergeConds(bean interface{}) error {
if !statement.noAutoCondition {
var addedTableName = (len(statement.JoinStr) > 0)
autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil {
return "", nil, err
return err
}
statement.cond = statement.cond.And(autoCond)
}
if err := statement.processIDParam(); err != nil {
return err
}
return nil
}
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
@ -940,14 +966,12 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
columnStr = "*"
}
var condSQL string
var condArgs []interface{}
var err error
if isStruct {
condSQL, condArgs, err = statement.genConds(bean)
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
@ -960,10 +984,16 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) {
statement.setRefValue(rValue(bean))
condSQL, condArgs, err := statement.genConds(bean)
func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
var condSQL string
var condArgs []interface{}
var err error
if len(beans) > 0 {
statement.setRefValue(rValue(beans[0]))
condSQL, condArgs, err = statement.genConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil {
return "", nil, err
}

View File

@ -86,7 +86,7 @@ func TestExtends(t *testing.T) {
}
tu3 := &tempUser2{tempUser{0, "extends update"}, ""}
_, err = testEngine.Id(tu2.TempUser.Id).Update(tu3)
_, err = testEngine.ID(tu2.TempUser.Id).Update(tu3)
if err != nil {
t.Error(err)
panic(err)
@ -124,7 +124,7 @@ func TestExtends(t *testing.T) {
}
tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}}
_, err = testEngine.Id(tu9.TempUser2.TempUser.Id).Update(tu10)
_, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10)
if err != nil {
t.Error(err)
panic(err)
@ -168,7 +168,7 @@ func TestExtends(t *testing.T) {
}
tu6 := &tempUser3{&tempUser{0, "extends update"}, ""}
_, err = testEngine.Id(tu5.Temp.Id).Update(tu6)
_, err = testEngine.ID(tu5.Temp.Id).Update(tu6)
if err != nil {
t.Error(err)
panic(err)

View File

@ -51,7 +51,7 @@ func TestCreatedAndUpdated(t *testing.T) {
}
u.Name = "xxx"
cnt, err = testEngine.Id(u.Id).Update(u)
cnt, err = testEngine.ID(u.Id).Update(u)
if err != nil {
t.Error(err)
panic(err)
@ -110,48 +110,62 @@ func TestStrangeName(t *testing.T) {
}
}
type CreatedUpdated struct {
Id int64
Name string
Value float64 `xorm:"numeric"`
Created time.Time `xorm:"created"`
Created2 time.Time `xorm:"created"`
Updated time.Time `xorm:"updated"`
}
func TestCreatedUpdated(t *testing.T) {
assert.NoError(t, prepareEngine())
err := testEngine.Sync(&CreatedUpdated{})
if err != nil {
t.Error(err)
panic(err)
type CreatedUpdated struct {
Id int64
Name string
Value float64 `xorm:"numeric"`
Created time.Time `xorm:"created"`
Created2 time.Time `xorm:"created"`
Updated time.Time `xorm:"updated"`
}
err := testEngine.Sync(&CreatedUpdated{})
assert.NoError(t, err)
c := &CreatedUpdated{Name: "test"}
_, err = testEngine.Insert(c)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
c2 := new(CreatedUpdated)
has, err := testEngine.Id(c.Id).Get(c2)
if err != nil {
t.Error(err)
panic(err)
}
has, err := testEngine.ID(c.Id).Get(c2)
assert.NoError(t, err)
if !has {
panic(errors.New("no id"))
}
assert.True(t, has)
c2.Value -= 1
_, err = testEngine.Id(c2.Id).Update(c2)
if err != nil {
t.Error(err)
panic(err)
_, err = testEngine.ID(c2.Id).Update(c2)
assert.NoError(t, err)
}
func TestCreatedUpdatedInt64(t *testing.T) {
assert.NoError(t, prepareEngine())
type CreatedUpdatedInt64 struct {
Id int64
Name string
Value float64 `xorm:"numeric"`
Created int64 `xorm:"created"`
Created2 int64 `xorm:"created"`
Updated int64 `xorm:"updated"`
}
assertSync(t, &CreatedUpdatedInt64{})
c := &CreatedUpdatedInt64{Name: "test"}
_, err := testEngine.Insert(c)
assert.NoError(t, err)
c2 := new(CreatedUpdatedInt64)
has, err := testEngine.ID(c.Id).Get(c2)
assert.NoError(t, err)
assert.True(t, has)
c2.Value -= 1
_, err = testEngine.ID(c2.Id).Update(c2)
assert.NoError(t, err)
}
type Lowercase struct {
@ -270,3 +284,77 @@ func TestTagComment(t *testing.T) {
assert.EqualValues(t, 1, len(tables[0].Columns()))
assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment)
}
func TestTagDefault(t *testing.T) {
assert.NoError(t, prepareEngine())
type DefaultStruct struct {
Id int64
Name string
Age int `xorm:"default(10)"`
}
assertSync(t, new(DefaultStruct))
cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{
Name: "test",
Age: 20,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s DefaultStruct
has, err := testEngine.ID(1).Get(&s)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 10, s.Age)
assert.EqualValues(t, "test", s.Name)
}
func TestTagsDirection(t *testing.T) {
assert.NoError(t, prepareEngine())
type OnlyFromDBStruct struct {
Id int64
Name string
Uuid string `xorm:"<- default '1'"`
}
assertSync(t, new(OnlyFromDBStruct))
cnt, err := testEngine.Insert(&OnlyFromDBStruct{
Name: "test",
Uuid: "2",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s OnlyFromDBStruct
has, err := testEngine.ID(1).Get(&s)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "1", s.Uuid)
assert.EqualValues(t, "test", s.Name)
type OnlyToDBStruct struct {
Id int64
Name string
Uuid string `xorm:"->"`
}
assertSync(t, new(OnlyToDBStruct))
cnt, err = testEngine.Insert(&OnlyToDBStruct{
Name: "test",
Uuid: "2",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s2 OnlyToDBStruct
has, err = testEngine.ID(1).Get(&s2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "", s2.Uuid)
assert.EqualValues(t, "test", s2.Name)
}

View File

@ -49,7 +49,7 @@ func TestVersion1(t *testing.T) {
}
newVer := new(VersionS)
has, err := testEngine.Id(ver.Id).Get(newVer)
has, err := testEngine.ID(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
panic(err)
@ -67,7 +67,7 @@ func TestVersion1(t *testing.T) {
}
newVer.Name = "-------"
_, err = testEngine.Id(ver.Id).Update(newVer)
_, err = testEngine.ID(ver.Id).Update(newVer)
if err != nil {
t.Error(err)
panic(err)
@ -78,7 +78,7 @@ func TestVersion1(t *testing.T) {
}
newVer = new(VersionS)
has, err = testEngine.Id(ver.Id).Get(newVer)
has, err = testEngine.ID(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
panic(err)

View File

@ -176,7 +176,7 @@ func TestNullStructUpdate(t *testing.T) {
item.Age = sql.NullInt64{23, true}
item.Height = sql.NullFloat64{0, false} // update to NULL
affected, err := testEngine.Id(2).Cols("age", "height", "is_man").Update(item)
affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item)
if err != nil {
t.Error(err)
panic(err)
@ -224,7 +224,7 @@ func TestNullStructUpdate(t *testing.T) {
// IsMan: sql.NullBool{true, true},
}
_, err := testEngine.AllCols().Id(6).Update(item)
_, err := testEngine.AllCols().ID(6).Update(item)
if err != nil {
t.Error(err)
panic(err)
@ -268,7 +268,7 @@ func TestNullStructFind(t *testing.T) {
if true {
item := new(NullType)
has, err := testEngine.Id(1).Get(item)
has, err := testEngine.ID(1).Get(item)
if err != nil {
t.Error(err)
panic(err)
@ -305,7 +305,7 @@ func TestNullStructFind(t *testing.T) {
if true {
item := make([]NullType, 0)
err := testEngine.Id(2).Find(&item)
err := testEngine.ID(2).Find(&item)
if err != nil {
t.Error(err)
panic(err)
@ -390,7 +390,7 @@ func TestNullStructDelete(t *testing.T) {
item := new(NullType)
_, err := testEngine.Id(1).Delete(item)
_, err := testEngine.ID(1).Delete(item)
if err != nil {
t.Error(err)
panic(err)

View File

@ -37,7 +37,7 @@ func TestArrayField(t *testing.T) {
assert.EqualValues(t, 1, cnt)
var arr ArrayStruct
has, err := testEngine.Id(1).Get(&arr)
has, err := testEngine.ID(1).Get(&arr)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, as.Name, arr.Name)
@ -320,7 +320,7 @@ func TestCustomType2(t *testing.T) {
assert.NoError(t, err)
user := UserCus{}
exist, err := testEngine.Id(1).Get(&user)
exist, err := testEngine.ID(1).Get(&user)
assert.NoError(t, err)
assert.True(t, exist)

View File

@ -17,7 +17,7 @@ import (
const (
// Version show the xorm's version
Version string = "0.6.2.0605"
Version string = "0.6.3.0713"
)
func regDrvsNDialects() bool {
@ -50,10 +50,13 @@ func close(engine *Engine) {
engine.Close()
}
func init() {
regDrvsNDialects()
}
// NewEngine new a db manager according to the parameter. Currently support four
// drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
regDrvsNDialects()
driver := core.QueryDriver(driverName)
if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName)

View File

@ -12,6 +12,7 @@ import (
"github.com/go-xorm/core"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
)
var (

View File

@ -496,7 +496,7 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
func (c *context) File(file string) (err error) {
f, err := os.Open(file)
if err != nil {
return ErrNotFound
return NotFoundHandler(c)
}
defer f.Close()
@ -505,7 +505,7 @@ func (c *context) File(file string) (err error) {
file = filepath.Join(file, indexPage)
f, err = os.Open(file)
if err != nil {
return ErrNotFound
return NotFoundHandler(c)
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
@ -525,7 +525,7 @@ func (c *context) Inline(file, name string) (err error) {
}
func (c *context) contentDisposition(file, name, dispositionType string) (err error) {
c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%s", dispositionType, name))
c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name))
c.File(file)
return
}

View File

@ -187,7 +187,7 @@ func TestContext(t *testing.T) {
err = c.Attachment("_fixture/images/walle.png", "walle.png")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "attachment; filename=walle.png", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len())
}
@ -197,7 +197,7 @@ func TestContext(t *testing.T) {
err = c.Inline("_fixture/images/walle.png", "walle.png")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "inline; filename=walle.png", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len())
}

View File

@ -282,7 +282,7 @@ func New() (e *Echo) {
e.TLSServer.Handler = e
e.HTTPErrorHandler = e.DefaultHTTPErrorHandler
e.Binder = &DefaultBinder{}
e.Logger.SetLevel(log.OFF)
e.Logger.SetLevel(log.ERROR)
e.stdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0)
e.pool.New = func() interface{} {
return e.NewContext(nil, nil)
@ -295,7 +295,7 @@ func New() (e *Echo) {
func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context {
return &context{
request: r,
response: &Response{echo: e, Writer: w},
response: NewResponse(w, e),
store: make(Map),
echo: e,
pvalues: make([]string, *e.maxParam),

View File

@ -21,66 +21,66 @@ func (g *Group) Use(middleware ...MiddlewareFunc) {
// Allow all requests to reach the group as they might get dropped if router
// doesn't find a match, making none of the group middleware process.
g.echo.Any(path.Clean(g.prefix+"/*"), func(c Context) error {
return ErrNotFound
return NotFoundHandler(c)
}, g.middleware...)
}
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(CONNECT, path, h, m...)
return g.Add(CONNECT, path, h, m...)
}
// DELETE implements `Echo#DELETE()` for sub-routes within the Group.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(DELETE, path, h, m...)
return g.Add(DELETE, path, h, m...)
}
// GET implements `Echo#GET()` for sub-routes within the Group.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(GET, path, h, m...)
return g.Add(GET, path, h, m...)
}
// HEAD implements `Echo#HEAD()` for sub-routes within the Group.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(HEAD, path, h, m...)
return g.Add(HEAD, path, h, m...)
}
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(OPTIONS, path, h, m...)
return g.Add(OPTIONS, path, h, m...)
}
// PATCH implements `Echo#PATCH()` for sub-routes within the Group.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(PATCH, path, h, m...)
return g.Add(PATCH, path, h, m...)
}
// POST implements `Echo#POST()` for sub-routes within the Group.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(POST, path, h, m...)
return g.Add(POST, path, h, m...)
}
// PUT implements `Echo#PUT()` for sub-routes within the Group.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(PUT, path, h, m...)
return g.Add(PUT, path, h, m...)
}
// TRACE implements `Echo#TRACE()` for sub-routes within the Group.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(TRACE, path, h, m...)
return g.Add(TRACE, path, h, m...)
}
// Any implements `Echo#Any()` for sub-routes within the Group.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
for _, m := range methods {
g.add(m, path, handler, middleware...)
g.Add(m, path, handler, middleware...)
}
}
// Match implements `Echo#Match()` for sub-routes within the Group.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
for _, m := range methods {
g.add(m, path, handler, middleware...)
g.Add(m, path, handler, middleware...)
}
}
@ -102,7 +102,8 @@ func (g *Group) File(path, file string) {
g.echo.File(g.prefix+path, file)
}
func (g *Group) add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
// Add implements `Echo#Add()` for sub-routes within the Group.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
// Combine into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.

View File

@ -141,7 +141,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.keyFunc)
} else {
claims := reflect.ValueOf(config.Claims).Interface().(jwt.Claims)
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
}
if err == nil && token.Valid {

View File

@ -22,6 +22,42 @@ type jwtCustomClaims struct {
jwtCustomInfo
}
func TestJWTRace(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
validKey := []byte("secret")
h := JWTWithConfig(JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: validKey,
})(handler)
makeReq := func(token string) echo.Context {
req := httptest.NewRequest(echo.GET, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
c := e.NewContext(req, res)
assert.NoError(t, h(c))
return c
}
c := makeReq(initialToken)
user := c.Get("user").(*jwt.Token)
claims := user.Claims.(*jwtCustomClaims)
assert.Equal(t, claims.Name, "John Doe")
makeReq(raceToken)
user = c.Get("user").(*jwt.Token)
claims = user.Claims.(*jwtCustomClaims)
// Initial context should still be "John Doe", not "Race Condition"
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
}
func TestJWT(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {

View File

@ -142,6 +142,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
tgt := config.Balancer.Next()

View File

@ -20,6 +20,11 @@ type (
}
)
// NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
return &Response{Writer: w, echo: e}
}
// Header returns the header map for the writer that will be sent by
// WriteHeader. Changing the header after a call to WriteHeader (or Write) has
// no effect unless the modified headers were declared as trailers by setting

View File

@ -1,4 +1,4 @@
// +build go1.7,!go1.8
// +build go1.7, !go1.8
package echo

6
vendor/github.com/lib/pq/array.go generated vendored
View File

@ -13,7 +13,7 @@ import (
var typeByteSlice = reflect.TypeOf([]byte{})
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension.
@ -278,7 +278,7 @@ func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]b
// TODO calculate the assign function for other types
// TODO repeat this section on the element type of arrays or slices (multidimensional)
{
if reflect.PtrTo(rt).Implements(typeSqlScanner) {
if reflect.PtrTo(rt).Implements(typeSQLScanner) {
// dest is always addressable because it is an element of a slice.
assign = func(src []byte, dest reflect.Value) (err error) {
ss := dest.Addr().Interface().(sql.Scanner)
@ -587,7 +587,7 @@ func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
}
}
var del string = ","
var del = ","
var err error
var iv interface{} = rv.Interface()

83
vendor/github.com/lib/pq/conn.go generated vendored
View File

@ -27,12 +27,12 @@ var (
ErrNotSupported = errors.New("pq: Unsupported command")
ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
errUnexpectedReady = errors.New("unexpected ReadyForQuery")
errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
errNoLastInsertId = errors.New("no LastInsertId available after the empty statement")
errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
type Driver struct{}
@ -131,7 +131,7 @@ type conn struct {
}
// Handle driver-side settings in parsed connection string.
func (c *conn) handleDriverSettings(o values) (err error) {
func (cn *conn) handleDriverSettings(o values) (err error) {
boolSetting := func(key string, val *bool) error {
if value, ok := o[key]; ok {
if value == "yes" {
@ -145,18 +145,18 @@ func (c *conn) handleDriverSettings(o values) (err error) {
return nil
}
err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult)
err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
if err != nil {
return err
}
err = boolSetting("binary_parameters", &c.binaryParameters)
err = boolSetting("binary_parameters", &cn.binaryParameters)
if err != nil {
return err
}
return nil
}
func (c *conn) handlePgpass(o values) {
func (cn *conn) handlePgpass(o values) {
// if a password was supplied, do not process .pgpass
if _, ok := o["password"]; ok {
return
@ -229,10 +229,10 @@ func (c *conn) handlePgpass(o values) {
}
}
func (c *conn) writeBuf(b byte) *writeBuf {
c.scratch[0] = b
func (cn *conn) writeBuf(b byte) *writeBuf {
cn.scratch[0] = b
return &writeBuf{
buf: c.scratch[:5],
buf: cn.scratch[:5],
pos: 1,
}
}
@ -310,9 +310,8 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
u, err := userCurrent()
if err != nil {
return nil, err
} else {
o["user"] = u
}
o["user"] = u
}
cn := &conn{
@ -698,7 +697,7 @@ var emptyRows noRows
var _ driver.Result = noRows{}
func (noRows) LastInsertId() (int64, error) {
return 0, errNoLastInsertId
return 0, errNoLastInsertID
}
func (noRows) RowsAffected() (int64, error) {
@ -840,16 +839,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
cn.postExecuteWorkaround()
return rows, nil
} else {
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
}, nil
}
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
}, nil
}
// Implement the optional "Execer" interface for one-shot queries
@ -876,17 +874,16 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
cn.postExecuteWorkaround()
res, _, err = cn.readExecuteResponse("Execute")
return res, err
} else {
// Use the unnamed statement to defer planning until bind
// time, or else value-based selectivity estimates cannot be
// used.
st := cn.prepareTo(query, "")
r, err := st.Exec(args)
if err != nil {
panic(err)
}
return r, err
}
// Use the unnamed statement to defer planning until bind
// time, or else value-based selectivity estimates cannot be
// used.
st := cn.prepareTo(query, "")
r, err := st.Exec(args)
if err != nil {
panic(err)
}
return r, err
}
func (cn *conn) send(m *writeBuf) {
@ -1147,10 +1144,10 @@ const formatText format = 0
const formatBinary format = 1
// One result-column format code with the value 1 (i.e. all binary).
var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1}
var colFmtDataAllBinary = []byte{0, 1, 0, 1}
// No result-column format codes (i.e. all text).
var colFmtDataAllText []byte = []byte{0, 0}
var colFmtDataAllText = []byte{0, 0}
type stmt struct {
cn *conn
@ -1515,7 +1512,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
cn.send(b)
}
func (c *conn) processParameterStatus(r *readBuf) {
func (cn *conn) processParameterStatus(r *readBuf) {
var err error
param := r.string()
@ -1526,13 +1523,13 @@ func (c *conn) processParameterStatus(r *readBuf) {
var minor int
_, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
if err == nil {
c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
}
case "TimeZone":
c.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
if err != nil {
c.parameterStatus.currentLocation = nil
cn.parameterStatus.currentLocation = nil
}
default:
@ -1540,8 +1537,8 @@ func (c *conn) processParameterStatus(r *readBuf) {
}
}
func (c *conn) processReadyForQuery(r *readBuf) {
c.txnStatus = transactionStatus(r.byte())
func (cn *conn) processReadyForQuery(r *readBuf) {
cn.txnStatus = transactionStatus(r.byte())
}
func (cn *conn) readReadyForQuery() {
@ -1556,9 +1553,9 @@ func (cn *conn) readReadyForQuery() {
}
}
func (c *conn) processBackendKeyData(r *readBuf) {
c.processID = r.int32()
c.secretKey = r.int32()
func (cn *conn) processBackendKeyData(r *readBuf) {
cn.processID = r.int32()
cn.secretKey = r.int32()
}
func (cn *conn) readParseResponse() {

View File

@ -136,7 +136,7 @@ func TestOpenURL(t *testing.T) {
testURL("postgresql://")
}
const pgpass_file = "/tmp/pqgotest_pgpass"
const pgpassFile = "/tmp/pqgotest_pgpass"
func TestPgpass(t *testing.T) {
if os.Getenv("TRAVIS") != "true" {
@ -172,10 +172,10 @@ func TestPgpass(t *testing.T) {
txn.Rollback()
}
testAssert("", "ok", "missing .pgpass, unexpected error %#v")
os.Setenv("PGPASSFILE", pgpass_file)
os.Setenv("PGPASSFILE", pgpassFile)
testAssert("host=/tmp", "fail", ", unexpected error %#v")
os.Remove(pgpass_file)
pgpass, err := os.OpenFile(pgpass_file, os.O_RDWR|os.O_CREATE, 0644)
os.Remove(pgpassFile)
pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
t.Fatalf("Unexpected error writing pgpass file %#v", err)
}
@ -213,7 +213,7 @@ localhost:*:*:*:pass_C
// wrong permissions for the pgpass file means it should be ignored
assertPassword(values{"host": "example.com", "user": "foo"}, "")
// fix the permissions and check if it has taken effect
os.Chmod(pgpass_file, 0600)
os.Chmod(pgpassFile, 0600)
assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A")
assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback")
assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B")
@ -221,7 +221,7 @@ localhost:*:*:*:pass_C
assertPassword(values{"host": "", "user": "some_user"}, "pass_C")
assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C")
// cleanup
os.Remove(pgpass_file)
os.Remove(pgpassFile)
os.Setenv("PGPASSFILE", "")
}
@ -393,8 +393,8 @@ func TestEmptyQuery(t *testing.T) {
if _, err := res.RowsAffected(); err != errNoRowsAffected {
t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
}
if _, err := res.LastInsertId(); err != errNoLastInsertId {
t.Fatalf("expected %s, got %v", errNoLastInsertId, err)
if _, err := res.LastInsertId(); err != errNoLastInsertID {
t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
}
rows, err := db.Query("")
if err != nil {
@ -425,8 +425,8 @@ func TestEmptyQuery(t *testing.T) {
if _, err := res.RowsAffected(); err != errNoRowsAffected {
t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
}
if _, err := res.LastInsertId(); err != errNoLastInsertId {
t.Fatalf("expected %s, got %v", errNoLastInsertId, err)
if _, err := res.LastInsertId(); err != errNoLastInsertID {
t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
}
rows, err = stmt.Query()
if err != nil {
@ -1053,16 +1053,16 @@ func TestIssue282(t *testing.T) {
db := openTestConn(t)
defer db.Close()
var search_path string
var searchPath string
err := db.QueryRow(`
SET LOCAL search_path TO pg_catalog;
SET LOCAL search_path TO pg_catalog;
SHOW search_path`).Scan(&search_path)
SHOW search_path`).Scan(&searchPath)
if err != nil {
t.Fatal(err)
}
if search_path != "pg_catalog" {
t.Fatalf("unexpected search_path %s", search_path)
if searchPath != "pg_catalog" {
t.Fatalf("unexpected search_path %s", searchPath)
}
}

View File

@ -370,17 +370,17 @@ func TestInfinityTimestamp(t *testing.T) {
t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String())
}
y_1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC)
ym1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC)
y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC)
var s string
err = db.QueryRow("SELECT $1::timestamp::text", y_1500).Scan(&s)
err = db.QueryRow("SELECT $1::timestamp::text", ym1500).Scan(&s)
if err != nil {
t.Errorf("Encoding -infinity, expected no error, got %q", err)
}
if s != "-infinity" {
t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s)
}
err = db.QueryRow("SELECT $1::timestamptz::text", y_1500).Scan(&s)
err = db.QueryRow("SELECT $1::timestamptz::text", ym1500).Scan(&s)
if err != nil {
t.Errorf("Encoding -infinity, expected no error, got %q", err)
}

View File

@ -33,7 +33,7 @@ func TestDecodeUUIDBackend(t *testing.T) {
db := openTestConn(t)
defer db.Close()
var s string = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a"
var s = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a"
var scanned interface{}
err := db.QueryRow(`SELECT $1::uuid`, s).Scan(&scanned)

View File

@ -29,6 +29,15 @@ const (
backgroundMask = (backgroundRed | backgroundBlue | backgroundGreen | backgroundIntensity)
)
const (
genericRead = 0x80000000
genericWrite = 0x40000000
)
const (
consoleTextmodeBuffer = 0x1
)
type wchar uint16
type short int16
type dword uint32
@ -69,14 +78,17 @@ var (
procGetConsoleCursorInfo = kernel32.NewProc("GetConsoleCursorInfo")
procSetConsoleCursorInfo = kernel32.NewProc("SetConsoleCursorInfo")
procSetConsoleTitle = kernel32.NewProc("SetConsoleTitleW")
procCreateConsoleScreenBuffer = kernel32.NewProc("CreateConsoleScreenBuffer")
)
// Writer provide colorable Writer to the console
type Writer struct {
out io.Writer
handle syscall.Handle
oldattr word
oldpos coord
out io.Writer
handle syscall.Handle
althandle syscall.Handle
oldattr word
oldpos coord
rest bytes.Buffer
}
// NewColorable return new instance of Writer which handle escape sequence from File.
@ -407,7 +419,18 @@ func (w *Writer) Write(data []byte) (n int, err error) {
var csbi consoleScreenBufferInfo
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
er := bytes.NewReader(data)
handle := w.handle
var er *bytes.Reader
if w.rest.Len() > 0 {
var rest bytes.Buffer
w.rest.WriteTo(&rest)
w.rest.Reset()
rest.Write(data)
er = bytes.NewReader(rest.Bytes())
} else {
er = bytes.NewReader(data)
}
var bw [1]byte
loop:
for {
@ -426,28 +449,42 @@ loop:
}
if c2 == ']' {
if err := doTitleSequence(er); err != nil {
w.rest.WriteByte(c1)
w.rest.WriteByte(c2)
er.WriteTo(&w.rest)
if bytes.IndexByte(w.rest.Bytes(), 0x07) == -1 {
break loop
}
er = bytes.NewReader(w.rest.Bytes()[2:])
err := doTitleSequence(er)
if err != nil {
break loop
}
w.rest.Reset()
continue
}
if c2 != 0x5b {
continue
}
w.rest.WriteByte(c1)
w.rest.WriteByte(c2)
er.WriteTo(&w.rest)
var buf bytes.Buffer
var m byte
for {
c, err := er.ReadByte()
if err != nil {
break loop
}
for i, c := range w.rest.Bytes()[2:] {
if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' {
m = c
er = bytes.NewReader(w.rest.Bytes()[2+i+1:])
w.rest.Reset()
break
}
buf.Write([]byte(string(c)))
}
if m == 0 {
break loop
}
switch m {
case 'A':
@ -455,61 +492,64 @@ loop:
if err != nil {
continue
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.y -= short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'B':
n, err = strconv.Atoi(buf.String())
if err != nil {
continue
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.y += short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'C':
n, err = strconv.Atoi(buf.String())
if err != nil {
continue
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x += short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'D':
n, err = strconv.Atoi(buf.String())
if err != nil {
continue
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x -= short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
if csbi.cursorPosition.x < 0 {
csbi.cursorPosition.x = 0
}
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'E':
n, err = strconv.Atoi(buf.String())
if err != nil {
continue
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x = 0
csbi.cursorPosition.y += short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'F':
n, err = strconv.Atoi(buf.String())
if err != nil {
continue
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x = 0
csbi.cursorPosition.y -= short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'G':
n, err = strconv.Atoi(buf.String())
if err != nil {
continue
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x = short(n - 1)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'H':
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'H', 'f':
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
if buf.Len() > 0 {
token := strings.Split(buf.String(), ";")
switch len(token) {
@ -534,7 +574,7 @@ loop:
} else {
csbi.cursorPosition.y = 0
}
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'J':
n := 0
if buf.Len() > 0 {
@ -545,7 +585,7 @@ loop:
}
var count, written dword
var cursor coord
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
switch n {
case 0:
cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y}
@ -557,8 +597,8 @@ loop:
cursor = coord{x: csbi.window.left, y: csbi.window.top}
count = dword(csbi.size.x - csbi.cursorPosition.x + (csbi.size.y-csbi.cursorPosition.y)*csbi.size.x)
}
procFillConsoleOutputCharacter.Call(uintptr(w.handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputAttribute.Call(uintptr(w.handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputCharacter.Call(uintptr(handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputAttribute.Call(uintptr(handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
case 'K':
n := 0
if buf.Len() > 0 {
@ -567,26 +607,28 @@ loop:
continue
}
}
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
var cursor coord
var count, written dword
switch n {
case 0:
cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y}
count = dword(csbi.size.x - csbi.cursorPosition.x)
case 1:
cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y}
count = dword(csbi.size.x - csbi.cursorPosition.x)
case 2:
cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y}
count = dword(csbi.size.x)
}
var count, written dword
count = dword(csbi.size.x - csbi.cursorPosition.x)
procFillConsoleOutputCharacter.Call(uintptr(w.handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputAttribute.Call(uintptr(w.handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputCharacter.Call(uintptr(handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputAttribute.Call(uintptr(handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
case 'm':
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
attr := csbi.attributes
cs := buf.String()
if cs == "" {
procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(w.oldattr))
procSetConsoleTextAttribute.Call(uintptr(handle), uintptr(w.oldattr))
continue
}
token := strings.Split(cs, ";")
@ -625,6 +667,21 @@ loop:
attr |= n256foreAttr[n256]
i += 2
}
} else if len(token) == 5 && token[i+1] == "2" {
var r, g, b int
r, _ = strconv.Atoi(token[i+2])
g, _ = strconv.Atoi(token[i+3])
b, _ = strconv.Atoi(token[i+4])
i += 4
if r > 127 {
attr |= foregroundRed
}
if g > 127 {
attr |= foregroundGreen
}
if b > 127 {
attr |= foregroundBlue
}
} else {
attr = attr & (w.oldattr & backgroundMask)
}
@ -652,6 +709,21 @@ loop:
attr |= n256backAttr[n256]
i += 2
}
} else if len(token) == 5 && token[i+1] == "2" {
var r, g, b int
r, _ = strconv.Atoi(token[i+2])
g, _ = strconv.Atoi(token[i+3])
b, _ = strconv.Atoi(token[i+4])
i += 4
if r > 127 {
attr |= backgroundRed
}
if g > 127 {
attr |= backgroundGreen
}
if b > 127 {
attr |= backgroundBlue
}
} else {
attr = attr & (w.oldattr & foregroundMask)
}
@ -683,30 +755,52 @@ loop:
attr |= backgroundBlue
}
}
procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(attr))
procSetConsoleTextAttribute.Call(uintptr(handle), uintptr(attr))
}
}
case 'h':
var ci consoleCursorInfo
cs := buf.String()
if cs == "?25" {
var ci consoleCursorInfo
procGetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci)))
if cs == "5>" {
procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
ci.visible = 0
procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?25" {
procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
ci.visible = 1
procSetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci)))
procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?1049" {
if w.althandle == 0 {
h, _, _ := procCreateConsoleScreenBuffer.Call(uintptr(genericRead|genericWrite), 0, 0, uintptr(consoleTextmodeBuffer), 0, 0)
w.althandle = syscall.Handle(h)
if w.althandle != 0 {
handle = w.althandle
}
}
}
case 'l':
var ci consoleCursorInfo
cs := buf.String()
if cs == "?25" {
var ci consoleCursorInfo
procGetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci)))
if cs == "5>" {
procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
ci.visible = 1
procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?25" {
procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
ci.visible = 0
procSetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci)))
procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?1049" {
if w.althandle != 0 {
syscall.CloseHandle(w.althandle)
w.althandle = 0
handle = w.handle
}
}
case 's':
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
w.oldpos = csbi.cursorPosition
case 'u':
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&w.oldpos)))
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&w.oldpos)))
}
}

View File

@ -65,7 +65,7 @@ FAQ
* Want to get time.Time with current locale
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
* Can I use this in multiple routines concurrently?

View File

@ -167,7 +167,7 @@ type SQLiteDriver struct {
// SQLiteConn implement sql.Conn.
type SQLiteConn struct {
dbMu sync.Mutex
mu sync.Mutex
db *C.sqlite3
loc *time.Location
txlock string
@ -182,6 +182,7 @@ type SQLiteTx struct {
// SQLiteStmt implement sql.Stmt.
type SQLiteStmt struct {
mu sync.Mutex
c *SQLiteConn
s *C.sqlite3_stmt
t string
@ -202,6 +203,7 @@ type SQLiteRows struct {
cols []string
decltype []string
cls bool
closed bool
done chan struct{}
}
@ -761,9 +763,9 @@ func (c *SQLiteConn) Close() error {
return c.lastError()
}
deleteHandles(c)
c.dbMu.Lock()
c.mu.Lock()
c.db = nil
c.dbMu.Unlock()
c.mu.Unlock()
runtime.SetFinalizer(c, nil)
return nil
}
@ -772,8 +774,8 @@ func (c *SQLiteConn) dbConnOpen() bool {
if c == nil {
return false
}
c.dbMu.Lock()
defer c.dbMu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
return c.db != nil
}
@ -802,6 +804,8 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
// Close the statement.
func (s *SQLiteStmt) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return nil
}
@ -810,6 +814,7 @@ func (s *SQLiteStmt) Close() error {
return errors.New("sqlite statement with already closed database connection")
}
rv := C.sqlite3_finalize(s.s)
s.s = nil
if rv != C.SQLITE_OK {
return s.c.lastError()
}
@ -866,10 +871,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
if len(v) == 0 {
ln := len(v)
if ln == 0 {
v = placeHolder
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@ -904,6 +910,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
cols: nil,
decltype: nil,
cls: s.cls,
closed: false,
done: make(chan struct{}),
}
@ -976,25 +983,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
// Close the rows.
func (rc *SQLiteRows) Close() error {
if rc.s.closed {
rc.s.mu.Lock()
if rc.s.closed || rc.closed {
rc.s.mu.Unlock()
return nil
}
rc.closed = true
if rc.done != nil {
close(rc.done)
}
if rc.cls {
rc.s.mu.Unlock()
return rc.s.Close()
}
rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
rc.s.mu.Unlock()
return rc.s.c.lastError()
}
rc.s.mu.Unlock()
return nil
}
// Columns return column names.
func (rc *SQLiteRows) Columns() []string {
if rc.nc != len(rc.cols) {
rc.s.mu.Lock()
defer rc.s.mu.Unlock()
if rc.s.s != nil && rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
@ -1003,9 +1018,8 @@ func (rc *SQLiteRows) Columns() []string {
return rc.cols
}
// DeclTypes return column types.
func (rc *SQLiteRows) DeclTypes() []string {
if rc.decltype == nil {
func (rc *SQLiteRows) declTypes() []string {
if rc.s.s != nil && rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
@ -1014,8 +1028,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
return rc.decltype
}
// DeclTypes return column types.
func (rc *SQLiteRows) DeclTypes() []string {
rc.s.mu.Lock()
defer rc.s.mu.Unlock()
return rc.declTypes()
}
// Next move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
if rc.s.closed {
return io.EOF
}
rc.s.mu.Lock()
defer rc.s.mu.Unlock()
rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
@ -1028,7 +1054,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return nil
}
rc.DeclTypes()
rc.declTypes()
for i := range dest {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {

View File

@ -8,9 +8,13 @@
package sqlite3
import (
"context"
"database/sql"
"fmt"
"math/rand"
"os"
"testing"
"time"
)
func TestNamedParams(t *testing.T) {
@ -48,3 +52,91 @@ func TestNamedParams(t *testing.T) {
t.Error("Failed to db.QueryRow: not matched results")
}
}
var (
testTableStatements = []string{
`DROP TABLE IF EXISTS test_table`,
`
CREATE TABLE IF NOT EXISTS test_table (
key1 VARCHAR(64) PRIMARY KEY,
key_id VARCHAR(64) NOT NULL,
key2 VARCHAR(64) NOT NULL,
key3 VARCHAR(64) NOT NULL,
key4 VARCHAR(64) NOT NULL,
key5 VARCHAR(64) NOT NULL,
key6 VARCHAR(64) NOT NULL,
data BLOB NOT NULL
);`,
}
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
func randStringBytes(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}
func initDatabase(t *testing.T, db *sql.DB, rowCount int64) {
t.Logf("Executing db initializing statements")
for _, query := range testTableStatements {
_, err := db.Exec(query)
if err != nil {
t.Fatal(err)
}
}
for i := int64(0); i < rowCount; i++ {
query := `INSERT INTO test_table
(key1, key_id, key2, key3, key4, key5, key6, data)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?);`
args := []interface{}{
randStringBytes(50),
fmt.Sprint(i),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(2048),
}
_, err := db.Exec(query, args...)
if err != nil {
t.Fatal(err)
}
}
}
func TestShortTimeout(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared")
if err != nil {
t.Fatal(err)
}
defer db.Close()
initDatabase(t, db, 10000)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond)
defer cancel()
query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
FROM test_table
ORDER BY key2 ASC`
rows, err := db.QueryContext(ctx, query)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var key1, keyid, key2, key3, key4, key5, key6 string
var data []byte
err = rows.Scan(&key1, &keyid, &key2, &key3, &key4, &key5, &key6, &data)
if err != nil {
break
}
}
if context.DeadlineExceeded != ctx.Err() {
t.Fatal(ctx.Err())
}
}

View File

@ -10,5 +10,6 @@ package sqlite3
#cgo CFLAGS: -DUSE_LIBSQLITE3
#cgo linux LDFLAGS: -lsqlite3
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
#cgo solaris LDFLAGS: -lsqlite3
*/
import "C"

View File

@ -9,5 +9,6 @@ package sqlite3
/*
#cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl
#cgo solaris LDFLAGS: -lc
*/
import "C"

View File

@ -6,21 +6,22 @@
package sqlite3
import (
"bytes"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net/url"
"os"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/mattn/go-sqlite3/sqlite3_test"
)
func TempFilename(t *testing.T) string {
@ -870,18 +871,6 @@ func TestTimezoneConversion(t *testing.T) {
}
}
func TestSuite(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
t.Fatal(err)
}
defer db.Close()
sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
}
// TODO: Execer & Queryer currently disabled
// https://github.com/mattn/go-sqlite3/issues/82
func TestExecer(t *testing.T) {
@ -1355,6 +1344,61 @@ func TestUpdateAndTransactionHooks(t *testing.T) {
}
}
func TestNilAndEmptyBytes(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
actualNil := []byte("use this to use an actual nil not a reference to nil")
emptyBytes := []byte{}
for tsti, tst := range []struct {
name string
columnType string
insertBytes []byte
expectedBytes []byte
}{
{"actual nil blob", "blob", actualNil, nil},
{"referenced nil blob", "blob", nil, nil},
{"empty blob", "blob", emptyBytes, emptyBytes},
{"actual nil text", "text", actualNil, nil},
{"referenced nil text", "text", nil, nil},
{"empty text", "text", emptyBytes, emptyBytes},
} {
if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
t.Fatal(tst.name, err)
}
if bytes.Equal(tst.insertBytes, actualNil) {
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
t.Fatal(tst.name, err)
}
} else {
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
t.Fatal(tst.name, err)
}
}
rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
if err != nil {
t.Fatal(tst.name, err)
}
if !rows.Next() {
t.Fatal(tst.name, "no rows")
}
var scanBytes []byte
if err = rows.Scan(&scanBytes); err != nil {
t.Fatal(tst.name, err)
}
if err = rows.Err(); err != nil {
t.Fatal(tst.name, err)
}
if tst.expectedBytes == nil && scanBytes != nil {
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
} else if !bytes.Equal(scanBytes, tst.expectedBytes) {
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
}
}
}
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {
@ -1389,3 +1433,422 @@ func BenchmarkCustomFunctions(b *testing.B) {
}
}
}
func TestSuite(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
t.Fatal(err)
}
defer d.Close()
db = &TestDB{t, d, SQLITE, sync.Once{}}
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
if !testing.Short() {
for _, b := range benchmarks {
fmt.Printf("%-20s", b.Name)
r := testing.Benchmark(b.F)
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
}
}
db.tearDown()
}
// Dialect is a type of dialect of databases.
type Dialect int
// Dialects for databases.
const (
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
MYSQL // MYSQL mean MySQL dialect
)
// DB provide context for the tests
type TestDB struct {
*testing.T
*sql.DB
dialect Dialect
once sync.Once
}
var db *TestDB
// the following tables will be created and dropped during the test
var testTables = []string{"foo", "bar", "t", "bench"}
var tests = []testing.InternalTest{
{Name: "TestResult", F: testResult},
{Name: "TestBlobs", F: testBlobs},
{Name: "TestManyQueryRow", F: testManyQueryRow},
{Name: "TestTxQuery", F: testTxQuery},
{Name: "TestPreparedStmt", F: testPreparedStmt},
}
var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkExec", F: benchmarkExec},
{Name: "BenchmarkQuery", F: benchmarkQuery},
{Name: "BenchmarkParams", F: benchmarkParams},
{Name: "BenchmarkStmt", F: benchmarkStmt},
{Name: "BenchmarkRows", F: benchmarkRows},
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
}
func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
res, err := db.Exec(sql, args...)
if err != nil {
db.Fatalf("Error running %q: %v", sql, err)
}
return res
}
func (db *TestDB) tearDown() {
for _, tbl := range testTables {
switch db.dialect {
case SQLITE:
db.mustExec("drop table if exists " + tbl)
case MYSQL, POSTGRESQL:
db.mustExec("drop table if exists " + tbl)
default:
db.Fatal("unknown dialect")
}
}
}
// q replaces ? parameters if needed
func (db *TestDB) q(sql string) string {
switch db.dialect {
case POSTGRESQL: // repace with $1, $2, ..
qrx := regexp.MustCompile(`\?`)
n := 0
return qrx.ReplaceAllStringFunc(sql, func(string) string {
n++
return "$" + strconv.Itoa(n)
})
}
return sql
}
func (db *TestDB) blobType(size int) string {
switch db.dialect {
case SQLITE:
return fmt.Sprintf("blob[%d]", size)
case POSTGRESQL:
return "bytea"
case MYSQL:
return fmt.Sprintf("VARBINARY(%d)", size)
}
panic("unknown dialect")
}
func (db *TestDB) serialPK() string {
switch db.dialect {
case SQLITE:
return "integer primary key autoincrement"
case POSTGRESQL:
return "serial primary key"
case MYSQL:
return "integer primary key auto_increment"
}
panic("unknown dialect")
}
func (db *TestDB) now() string {
switch db.dialect {
case SQLITE:
return "datetime('now')"
case POSTGRESQL:
return "now()"
case MYSQL:
return "now()"
}
panic("unknown dialect")
}
func makeBench() {
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
panic(err)
}
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
if err != nil {
panic(err)
}
defer st.Close()
for i := 0; i < 100; i++ {
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
panic(err)
}
}
}
// testResult is test for result
func testResult(t *testing.T) {
db.tearDown()
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
for i := 1; i < 3; i++ {
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
n, err := r.RowsAffected()
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("got %v, want %v", n, 1)
}
n, err = r.LastInsertId()
if err != nil {
t.Fatal(err)
}
if n != int64(i) {
t.Errorf("got %v, want %v", n, i)
}
}
if _, err := db.Exec("error!"); err == nil {
t.Fatalf("expected error")
}
}
// testBlobs is test for blobs
func testBlobs(t *testing.T) {
db.tearDown()
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
// testManyQueryRow is test for many query row
func testManyQueryRow(t *testing.T) {
if testing.Short() {
t.Log("skipping in short mode")
return
}
db.tearDown()
db.mustExec("create table foo (id integer primary key, name varchar(50))")
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
// testTxQuery is test for transactional query
func testTxQuery(t *testing.T) {
db.tearDown()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
if err != nil {
t.Fatal(err)
}
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
// testPreparedStmt is test for prepared statement
func testPreparedStmt(t *testing.T) {
db.tearDown()
db.mustExec("CREATE TABLE t (count INT)")
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
wg.Wait()
}
// Benchmarks need to use panic() since b.Error errors are lost when
// running via testing.Benchmark() I would like to run these via go
// test -bench but calling Benchmark() from a benchmark test
// currently hangs go.
// benchmarkExec is benchmark for exec
func benchmarkExec(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := db.Exec("select 1"); err != nil {
panic(err)
}
}
}
// benchmarkQuery is benchmark for query
func benchmarkQuery(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// benchmarkParams is benchmark for params
func benchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// benchmarkStmt is benchmark for statement
func benchmarkStmt(b *testing.B) {
st, err := db.Prepare("select ?, ?, ?, ?")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// benchmarkRows is benchmark for rows
func benchmarkRows(b *testing.B) {
db.once.Do(makeBench)
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := db.Query("select * from bench")
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}
// benchmarkStmtRows is benchmark for statement rows
func benchmarkStmtRows(b *testing.B) {
db.once.Do(makeBench)
st, err := db.Prepare("select * from bench")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := st.Query()
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}

View File

@ -1,423 +0,0 @@
package sqlite3_test
import (
"database/sql"
"fmt"
"math/rand"
"regexp"
"strconv"
"sync"
"testing"
"time"
)
// Dialect is a type of dialect of databases.
type Dialect int
// Dialects for databases.
const (
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
MYSQL // MYSQL mean MySQL dialect
)
// DB provide context for the tests
type DB struct {
*testing.T
*sql.DB
dialect Dialect
once sync.Once
}
var db *DB
// the following tables will be created and dropped during the test
var testTables = []string{"foo", "bar", "t", "bench"}
var tests = []testing.InternalTest{
{Name: "TestBlobs", F: TestBlobs},
{Name: "TestManyQueryRow", F: TestManyQueryRow},
{Name: "TestTxQuery", F: TestTxQuery},
{Name: "TestPreparedStmt", F: TestPreparedStmt},
}
var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkExec", F: BenchmarkExec},
{Name: "BenchmarkQuery", F: BenchmarkQuery},
{Name: "BenchmarkParams", F: BenchmarkParams},
{Name: "BenchmarkStmt", F: BenchmarkStmt},
{Name: "BenchmarkRows", F: BenchmarkRows},
{Name: "BenchmarkStmtRows", F: BenchmarkStmtRows},
}
// RunTests runs the SQL test suite
func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
db = &DB{t, d, dialect, sync.Once{}}
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
if !testing.Short() {
for _, b := range benchmarks {
fmt.Printf("%-20s", b.Name)
r := testing.Benchmark(b.F)
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
}
}
db.tearDown()
}
func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
res, err := db.Exec(sql, args...)
if err != nil {
db.Fatalf("Error running %q: %v", sql, err)
}
return res
}
func (db *DB) tearDown() {
for _, tbl := range testTables {
switch db.dialect {
case SQLITE:
db.mustExec("drop table if exists " + tbl)
case MYSQL, POSTGRESQL:
db.mustExec("drop table if exists " + tbl)
default:
db.Fatal("unknown dialect")
}
}
}
// q replaces ? parameters if needed
func (db *DB) q(sql string) string {
switch db.dialect {
case POSTGRESQL: // repace with $1, $2, ..
qrx := regexp.MustCompile(`\?`)
n := 0
return qrx.ReplaceAllStringFunc(sql, func(string) string {
n++
return "$" + strconv.Itoa(n)
})
}
return sql
}
func (db *DB) blobType(size int) string {
switch db.dialect {
case SQLITE:
return fmt.Sprintf("blob[%d]", size)
case POSTGRESQL:
return "bytea"
case MYSQL:
return fmt.Sprintf("VARBINARY(%d)", size)
}
panic("unknown dialect")
}
func (db *DB) serialPK() string {
switch db.dialect {
case SQLITE:
return "integer primary key autoincrement"
case POSTGRESQL:
return "serial primary key"
case MYSQL:
return "integer primary key auto_increment"
}
panic("unknown dialect")
}
func (db *DB) now() string {
switch db.dialect {
case SQLITE:
return "datetime('now')"
case POSTGRESQL:
return "now()"
case MYSQL:
return "now()"
}
panic("unknown dialect")
}
func makeBench() {
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
panic(err)
}
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
if err != nil {
panic(err)
}
defer st.Close()
for i := 0; i < 100; i++ {
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
panic(err)
}
}
}
// TestResult is test for result
func TestResult(t *testing.T) {
db.tearDown()
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
for i := 1; i < 3; i++ {
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
n, err := r.RowsAffected()
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("got %v, want %v", n, 1)
}
n, err = r.LastInsertId()
if err != nil {
t.Fatal(err)
}
if n != int64(i) {
t.Errorf("got %v, want %v", n, i)
}
}
if _, err := db.Exec("error!"); err == nil {
t.Fatalf("expected error")
}
}
// TestBlobs is test for blobs
func TestBlobs(t *testing.T) {
db.tearDown()
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
// TestManyQueryRow is test for many query row
func TestManyQueryRow(t *testing.T) {
if testing.Short() {
t.Log("skipping in short mode")
return
}
db.tearDown()
db.mustExec("create table foo (id integer primary key, name varchar(50))")
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
// TestTxQuery is test for transactional query
func TestTxQuery(t *testing.T) {
db.tearDown()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
if err != nil {
t.Fatal(err)
}
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
// TestPreparedStmt is test for prepared statement
func TestPreparedStmt(t *testing.T) {
db.tearDown()
db.mustExec("CREATE TABLE t (count INT)")
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
wg.Wait()
}
// Benchmarks need to use panic() since b.Error errors are lost when
// running via testing.Benchmark() I would like to run these via go
// test -bench but calling Benchmark() from a benchmark test
// currently hangs go.
// BenchmarkExec is benchmark for exec
func BenchmarkExec(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := db.Exec("select 1"); err != nil {
panic(err)
}
}
}
// BenchmarkQuery is benchmark for query
func BenchmarkQuery(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// BenchmarkParams is benchmark for params
func BenchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// BenchmarkStmt is benchmark for statement
func BenchmarkStmt(b *testing.B) {
st, err := db.Prepare("select ?, ?, ?, ?")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// BenchmarkRows is benchmark for rows
func BenchmarkRows(b *testing.B) {
db.once.Do(makeBench)
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := db.Query("select * from bench")
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}
// BenchmarkStmtRows is benchmark for statement rows
func BenchmarkStmtRows(b *testing.B) {
db.once.Do(makeBench)
st, err := db.Prepare("select * from bench")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := st.Query()
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}

View File

@ -1,11 +1,13 @@
language: go
go:
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- 1.2.x
- 1.3.x
- 1.4.x
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- tip
install:

View File

@ -65,4 +65,6 @@ var (
ShouldHappenWithin = assertions.ShouldHappenWithin
ShouldNotHappenWithin = assertions.ShouldNotHappenWithin
ShouldBeChronological = assertions.ShouldBeChronological
ShouldBeError = assertions.ShouldBeError
)

View File

@ -1 +0,0 @@
package main

View File

@ -35,13 +35,13 @@ func init() {
func flags() {
flag.IntVar(&port, "port", 8080, "The port at which to serve http.")
flag.StringVar(&host, "host", "127.0.0.1", "The host at which to serve http.")
flag.DurationVar(&nap, "poll", quarterSecond, "The interval to wait between polling the file system for changes (default: 250ms).")
flag.IntVar(&packages, "packages", 10, "The number of packages to test in parallel. Higher == faster but more costly in terms of computing. (default: 10)")
flag.DurationVar(&nap, "poll", quarterSecond, "The interval to wait between polling the file system for changes.")
flag.IntVar(&packages, "packages", 10, "The number of packages to test in parallel. Higher == faster but more costly in terms of computing.")
flag.StringVar(&gobin, "gobin", "go", "The path to the 'go' binary (default: search on the PATH).")
flag.BoolVar(&cover, "cover", true, "Enable package-level coverage statistics. Requires Go 1.2+ and the go cover tool. (default: true)")
flag.IntVar(&depth, "depth", -1, "The directory scanning depth. If -1, scan infinitely deep directory structures. 0: scan working directory. 1+: Scan into nested directories, limited to value. (default: -1)")
flag.BoolVar(&cover, "cover", true, "Enable package-level coverage statistics. Requires Go 1.2+ and the go cover tool.")
flag.IntVar(&depth, "depth", -1, "The directory scanning depth. If -1, scan infinitely deep directory structures. 0: scan working directory. 1+: Scan into nested directories, limited to value.")
flag.StringVar(&timeout, "timeout", "0", "The test execution timeout if none is specified in the *.goconvey file (default is '0', which is the same as not providing this option).")
flag.StringVar(&watchedSuffixes, "watchedSuffixes", ".go", "A comma separated list of file suffixes to watch for modifications (default: .go).")
flag.StringVar(&watchedSuffixes, "watchedSuffixes", ".go", "A comma separated list of file suffixes to watch for modifications.")
flag.StringVar(&excludedDirs, "excludedDirs", "vendor,node_modules", "A comma separated list of directories that will be excluded from being watched")
flag.StringVar(&workDir, "workDir", "", "set goconvey working directory (default current directory)")
flag.BoolVar(&autoLaunchBrowser, "launchBrowser", true, "toggle auto launching of browser (default: true)")

Some files were not shown because too many files have changed in this diff Show More