Mode eingebaut

This commit is contained in:
konrad 2017-08-31 18:37:43 +02:00
parent fc139036c0
commit f2d411d666
485 changed files with 141934 additions and 95401 deletions

View File

@ -11,4 +11,10 @@ TODO::
* ~~Kofis löschen~~ * ~~Kofis löschen~~
* ~~Logout~~ * ~~Logout~~
* Konfis auf der Frontseite mit Websockets updaten * Konfis auf der Frontseite mit Websockets updaten
* Mode hinzufügen: in der Config soll man zwischen Gemeinden und Konfis umschalten können, so dass entweder die Gemeinden oder Konfis gegeneinander spielen * ~~Mode hinzufügen: in der Config soll man zwischen Gemeinden und Konfis umschalten können, so dass entweder die Gemeinden oder Konfis gegeneinander spielen~~
* Alles in Fenster packen
* Random Port (wenn man irgendwie den laufenden port finden kann)
* Front-Tabelle schöner
* Mini-Doku
* Inbetriebnahme
* Bedienung

40
add.go
View File

@ -6,6 +6,10 @@ import (
) )
func addKonfi(c echo.Context) error { func addKonfi(c echo.Context) error {
//Config
SiteConf := initConfig()
//Datenbankverbindung aufbauen //Datenbankverbindung aufbauen
db := DBinit() db := DBinit()
@ -18,16 +22,36 @@ func addKonfi(c echo.Context) error {
//Wenn eingeloggt //Wenn eingeloggt
if logged != nil { if logged != nil {
kofi := new(Kofi) // Mode nach Kofis
kofi.Name = c.FormValue("name") if SiteConf.Mode == 0 {
kofi.Gemeinde = c.FormValue("gemeinde") kofi := new(Kofi)
kofi.Name = c.FormValue("name")
kofi.Gemeinde = c.FormValue("gemeinde")
// Einfügen
_, err := db.Insert(kofi)
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
return c.JSON(http.StatusInternalServerError, Message{"Error."})
} else if SiteConf.Mode == 1 { // Mode nach Gemeinden
gemeinde := new(Gemeinde)
gemeinde.Name = c.FormValue("name")
_, err := db.Insert(gemeinde)
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
return c.JSON(http.StatusInternalServerError, Message{"Error."})
//Aktuelle Coins holen
_, err := db.Insert(kofi)
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
} }
return c.JSON(http.StatusOK, Message{"Error."})
return c.JSON(http.StatusInternalServerError, Message{"Wrong Mode."})
} else { } else {
return c.JSON(http.StatusOK, Message{"Login first."}) return c.JSON(http.StatusOK, Message{"Login first."})
} }

View File

@ -3,13 +3,18 @@ package main
import ( import (
"github.com/labstack/echo" "github.com/labstack/echo"
"net/http" "net/http"
"strconv"
) )
type Loggedin struct { type AdminInfos struct {
Loggedin bool Loggedin bool
Mode int
} }
func adminHandler(c echo.Context) error { func adminHandler(c echo.Context) error {
//Config
SiteConf := initConfig()
rw := c.Response() rw := c.Response()
r := c.Request() r := c.Request()
@ -20,8 +25,8 @@ func adminHandler(c echo.Context) error {
loggedin := sess.Get("login") loggedin := sess.Get("login")
if loggedin != nil { if loggedin != nil {
return c.Render(http.StatusOK, "admin", Loggedin{true}) return c.Render(http.StatusOK, "admin_mode_" + strconv.Itoa(SiteConf.Mode), AdminInfos{true, SiteConf.Mode})
} else { } else {
return c.Render(http.StatusOK, "admin", Loggedin{false}) return c.Render(http.StatusOK, "login", AdminInfos{false, SiteConf.Mode})
} }
} }

View File

@ -1,9 +1,13 @@
function getList() { function getList() {
$.getJSON('/list?asc=asc', function (data) { $.getJSON('/list?asc=asc', function (data) {
//console.log(data); //console.log(data);
$("#konfis").html(''); $("#list").html('');
$.each(data, function (i, item) { $.each(data, function (i, item) {
$("#konfis").append('<tr id="kcoins_row_' + item.ID + '"> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td id="kcoins_display_' + item.ID + '">' + item.KCoins + '</td><td><span class="ui action input" id="kcoins_container_' + item.ID + '"><input type="number" value="0" id="kcoins_' + item.ID + '" name="kcoins" autocomplete="off" /><button class="ui right labeled icon button green" onclick="updateCoins(\'' + item.ID + '\');"><i class="right dollar icon"></i>KonfiCoins Hinzufügen</button></span>&nbsp;&nbsp;&nbsp;&nbsp;<button class="ui button red" onclick="deleteKonfi(\'' + item.ID + '\');" id="kcoins_container_' + item.ID + '">Konfi Löschen</button></td></tr>'); if(item.Gemeinde !== undefined) {
$("#list").append('<tr id="kcoins_row_' + item.ID + '"> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td id="kcoins_display_' + item.ID + '">' + item.KCoins + '</td><td><span class="ui action input" id="kcoins_container_' + item.ID + '"><input type="number" value="0" id="kcoins_' + item.ID + '" name="kcoins" autocomplete="off" /><button class="ui right labeled icon button green" onclick="updateCoins(\'' + item.ID + '\');"><i class="right dollar icon"></i>KonfiCoins Hinzufügen</button></span>&nbsp;&nbsp;&nbsp;&nbsp;<button class="ui button red" onclick="deleteKonfi(\'' + item.ID + '\');" id="kcoins_container_' + item.ID + '">Konfi Löschen</button></td></tr>');
} else {
$("#list").append('<tr id="kcoins_row_' + item.ID + '"> <td>' + item.Name + '</td> <td id="kcoins_display_' + item.ID + '">' + item.KCoins + '</td><td><span class="ui action input" id="kcoins_container_' + item.ID + '"><input type="number" value="0" id="kcoins_' + item.ID + '" name="kcoins" autocomplete="off" /><button class="ui right labeled icon button green" onclick="updateCoins(\'' + item.ID + '\');"><i class="right dollar icon"></i>KonfiCoins Hinzufügen</button></span>&nbsp;&nbsp;&nbsp;&nbsp;<button class="ui button red" onclick="deleteGemeinde(\'' + item.ID + '\');" id="kcoins_container_' + item.ID + '">Gemeinde Löschen</button></td></tr>');
}
}); });
}); });
} }
@ -15,7 +19,7 @@ function updateCoins(id) {
var addcoins = $('#kcoins_' + id).val(); var addcoins = $('#kcoins_' + id).val();
if(addcoins != 0) { if(addcoins != 0) {
$('#kcoins_container_' + id).addClass('disabled'); $('#coins_container_' + id).addClass('disabled');
$.ajax({ $.ajax({
url: '/update', url: '/update',
@ -23,11 +27,11 @@ function updateCoins(id) {
data: 'id=' + id + '&addcoins=' + addcoins, data: 'id=' + id + '&addcoins=' + addcoins,
success: function (msg) { success: function (msg) {
console.log(msg); console.log(msg);
$('#kcoins_container_' + id).removeClass('disabled'); $('#coins_container_' + id).removeClass('disabled');
if (msg.Message == 'success') { if (msg.Message == 'success') {
$('#kcoins_' + id).val("0"); $('#kcoins_' + id).val("0");
$('#kcoins_display_' + id).html(msg.Kofi.KCoins); $('#kcoins_display_' + id).html(msg.Data.KCoins);
} else { } else {
$('#msg').html('<div class="ui error message" style="display: block;">Ein Fehler trat auf.</div>'); $('#msg').html('<div class="ui error message" style="display: block;">Ein Fehler trat auf.</div>');
@ -72,6 +76,42 @@ function deleteKonfi(id) {
; ;
} }
function deleteGemeinde(id) {
console.log('Delete', id);
$('#kcoins_container_' + id).addClass('disabled');
$('.ui.basic.gemeindedel.modal')
.modal({
closable : false,
duration: 200,
onDeny : function(){
$('#kcoins_container_' + id).removeClass('disabled');
return true;
},
onApprove : function() {
$.ajax({
url: '/delete',
method: 'POST',
data: 'id=' + id,
success: function (msg) {
console.log(msg);
if (msg.Message == 'success') {
//$('#kcoins_row_' + id).remove();
getList();
$('#msg').html('<div class="ui success message" style="display: block;">Die Gemeinde wurde erfolgreich gelöscht.</div>');
} else {
$('#msg').html('<div class="ui error message" style="display: block;">Ein Fehler trat auf.</div>');
}
}
});
}
})
.modal('show')
;
}
// Konfi hinzufügen
$('.ui.kofiadd.modal') $('.ui.kofiadd.modal')
.modal({ .modal({
duration: 200, duration: 200,
@ -101,6 +141,35 @@ $('.ui.kofiadd.modal')
.modal('attach events', '.addKofi.button', 'show') .modal('attach events', '.addKofi.button', 'show')
; ;
$('.ui.gemeindeadd.modal')
.modal({
duration: 200,
onApprove : function() {
$('.loader').addClass('active');
console.log('bul');
$.ajax({
url: '/add',
method: 'POST',
data: 'name=' + $('#name').val(),
success: function (msg) {
$('.loader').removeClass('active');
console.log(msg);
if (msg.Message == 'success') {
$('#name').val('');
getList();
$('#msg').html('<div class="ui success message" style="display: block;">Die Gemeinde wurde erfolgreich hinzugefügt.</div>');
} else {
$('#msg').html('<div class="ui error message" style="display: block;">Ein Fehler trat auf.</div>');
}
}
});
}
})
.modal('attach events', '.addGemeinde.button', 'show')
;
$('.ui.kofiupload.modal') $('.ui.kofiupload.modal')
.modal('attach events', '.ui.right.labeled.icon.uploadKofis.button.blue', 'show') .modal('attach events', '.ui.right.labeled.icon.uploadKofis.button.blue', 'show')
; ;

View File

@ -3,7 +3,11 @@ setInterval(function() {
//console.log(data); //console.log(data);
$( "#konfis" ).html(''); $( "#konfis" ).html('');
$.each( data, function( i, item ) { $.each( data, function( i, item ) {
$( "#konfis" ).append('<tr> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td>' + item.KCoins + '</td></tr>'); if (item.Gemeinde != undefined) {
$( "#konfis" ).append('<tr> <td>' + item.Name + '</td> <td>' + item.Gemeinde + '</td> <td>' + item.KCoins + '</td></tr>');
} else {
$( "#konfis" ).append('<tr> <td>' + item.Name + '</td> <td>' + item.KCoins + '</td></tr>');
}
}); });
}); });
}, 1000); }, 1000);

View File

@ -10,6 +10,7 @@ type Configuration struct {
AdminPassword string AdminPassword string
Interface string Interface string
DBFile string DBFile string
Mode int
} }
var SiteConf Configuration = Configuration{} var SiteConf Configuration = Configuration{}

View File

@ -1,3 +1,12 @@
; Das Adminpasswort, wird benötigt, um sich unter /admin einzuloggen
AdminPassword = geheim AdminPassword = geheim
; 0 = Konfis sind selbstständig
; 1 = Gemeinden spielen gegeneinancer
Mode = 1
; Serverkram
; Das Interface inkl. Port, auf dem der Webserver läuft
Interface = :8080 Interface = :8080
; Hier wird die Datenbank gespeichert
DBFile = ./data.db DBFile = ./data.db

View File

@ -7,6 +7,10 @@ import (
) )
func deleteKonfi(c echo.Context) error { func deleteKonfi(c echo.Context) error {
//Config
SiteConf := initConfig()
//Datenbankverbindung aufbauen //Datenbankverbindung aufbauen
db := DBinit() db := DBinit()
@ -22,12 +26,19 @@ func deleteKonfi(c echo.Context) error {
id, _ := strconv.Atoi(c.FormValue("id")) id, _ := strconv.Atoi(c.FormValue("id"))
//Löschen //Löschen
_, err := db.Id(id).Delete(&Kofi{}) if SiteConf.Mode == 0 {
if err == nil { _, err := db.Id(id).Delete(&Kofi{})
return c.JSON(http.StatusOK, Message{"success"}) if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
} else if SiteConf.Mode == 1{
_, err := db.Id(id).Delete(&Gemeinde{})
if err == nil {
return c.JSON(http.StatusOK, Message{"success"})
}
} }
return c.JSON(http.StatusOK, Message{"Error."}) return c.JSON(http.StatusInternalServerError, Message{"Error."})
} else { } else {
return c.JSON(http.StatusOK, Message{"Login first."}) return c.JSON(http.StatusForbidden, Message{"Login first."})
} }
} }

View File

@ -10,21 +10,46 @@ func getList(c echo.Context) error {
//Datenbankverbindung aufbauen //Datenbankverbindung aufbauen
db := DBinit() db := DBinit()
//Daten holen und anzeigen //Config
var kofi []Kofi SiteConf := initConfig()
asc := c.QueryParam("asc")
if asc == "" { if SiteConf.Mode == 0 {
err := db.OrderBy("KCoins DESC").Find(&kofi) //Daten holen und anzeigen
if err != nil { var kofi []Kofi
fmt.Println(err) asc := c.QueryParam("asc")
if asc == "" {
err := db.OrderBy("KCoins DESC").Find(&kofi)
if err != nil {
fmt.Println(err)
}
} else {
err := db.OrderBy("Name ASC").Find(&kofi)
if err != nil {
fmt.Println(err)
}
} }
} else {
err := db.OrderBy("Name ASC").Find(&kofi) //Template
if err != nil { return c.JSON(http.StatusOK, kofi)
fmt.Println(err) } else if SiteConf.Mode == 1 {
//Daten holen und anzeigen
var gemeinden []Gemeinde
asc := c.QueryParam("asc")
if asc == "" {
err := db.OrderBy("KCoins DESC").Find(&gemeinden)
if err != nil {
fmt.Println(err)
}
} else {
err := db.OrderBy("Name ASC").Find(&gemeinden)
if err != nil {
fmt.Println(err)
}
} }
//Template
return c.JSON(http.StatusOK, gemeinden)
} }
//Template return c.HTML(http.StatusInternalServerError, "Error. (Wrong mode)")
return c.JSON(http.StatusOK, kofi)
} }

View File

@ -79,7 +79,8 @@ func main() {
//DB init - Create tables //DB init - Create tables
db := DBinit() db := DBinit()
db.CreateTables(&Kofi{}) db.Sync(&Kofi{})
db.Sync(&Gemeinde{})
//Start the server //Start the server
e.Logger.SetLevel(log.ERROR) e.Logger.SetLevel(log.ERROR)

View File

@ -6,6 +6,10 @@ import (
) )
func showList(c echo.Context) error { func showList(c echo.Context) error {
//Config
SiteConf := initConfig()
//Template //Template
return c.Render(http.StatusOK, "index", Message{"schinken"}) return c.Render(http.StatusOK, "index", SiteConf)
} }

View File

@ -1,4 +1,4 @@
{{define "admin"}} {{define "admin_mode_0"}}
<!DOCTYPE html> <!DOCTYPE html>
<html lang="de"> <html lang="de">
<head> <head>
@ -37,7 +37,7 @@
<th>Bearbeiten</th> <th>Bearbeiten</th>
</tr> </tr>
</thead> </thead>
<tbody id="konfis"> <tbody id="list">
<tr> <tr>
<td colspan="3">Laden...</td> <td colspan="3">Laden...</td>
</tr> </tr>
@ -75,14 +75,12 @@
Konfi hinzufügen Konfi hinzufügen
</div> </div>
<div class="image content"> <div class="image content">
<form action="#" method="post">
<div class="ui input"> <div class="ui input">
<input type="text" id="name" placeholder="Name"/> <input type="text" id="name" placeholder="Name"/>
</div><br/><br/> </div><br/><br/>
<div class="ui input"> <div class="ui input">
<input type="text" id="gemeinde" placeholder="Gemeinde"/> <input type="text" id="gemeinde" placeholder="Gemeinde"/>
</div> </div>
</form>
</div> </div>
<div class="actions"> <div class="actions">
<div class="ui black deny button"> <div class="ui black deny button">
@ -108,41 +106,14 @@
<div class="ui black deny button"> <div class="ui black deny button">
Abbrechen Abbrechen
</div> </div>
<div class="ui positive button"> <div class="ui positive button">include
Hochladen Hochladen
</div> </div>
</div> </div>
</div> </div>
{{else}}
<body style="background: url(/assets/bg.jpg) no-repeat center fixed">
<div class="ui middle aligned center aligned grid" style="width: 30em; margin: 37vh auto;">
<div class="column">
<h2 class="ui header">
Kasino Admin
</h2>
<form class="ui large form" id="loginform" method="post">
<div class="ui segment">
<div class="field">
<div class="ui left icon input">
<i class="lock icon"></i>
<input type="password" name="password" id="password" placeholder="Passwort" autofocus>
</div>
</div>
<div class="ui fluid large blue submit button">Login</div>
</div>
<div id="msg"></div>
</form>
</div>
</div>
{{end}}
{{if .Loggedin}}
<script src="/assets/js/admin.js"></script> <script src="/assets/js/admin.js"></script>
{{else}}
<script src="/assets/js/login.js"></script>
{{end}} {{end}}
</body> </body>
</html> </html>
{{end}} {{end}}include

95
tpl/admin_mode_1.html Normal file
View File

@ -0,0 +1,95 @@
{{define "admin_mode_1"}}
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Kasino Admin</title>
<link rel="stylesheet" type="text/css" href="/assets/semantic/semantic.min.css">
<script src="/assets/js/jquery-3.1.1.min.js"></script>
<script src="/assets/semantic/semantic.min.js"></script>
</head>
{{if .Loggedin}}
<body>
<div style="width: 98%; margin: 0 auto;">
<h1>Kasino Admin
<div class="ui inline loader"></div>
</h1>
<a href="/logout" class="ui right labeled icon button blue" style="float: right;">
<i class="right sign out icon"></i>
Ausloggen
</a>
<p>
<a class="ui right labeled icon addGemeinde button green">
<i class="right plus icon"></i>
Gemeinde Hinzufügen
</a>
</p>
<div id="msg"></div>
<table width="100%" border="0" cellpadding="0" cellspacing="0" class="ui celled table">
<thead>
<tr>
<th>Name</th>
<th>KonfiCoins</th>
<th>Bearbeiten</th>
</tr>
</thead>
<tbody id="list">
<tr>
<td colspan="3">Laden...</td>
</tr>
</tbody>
</table>
</div>
<!-- Modals -->
<!-- Gemeinde löschen -->
<div class="ui small basic gemeindedel modal">
<div class="ui icon header">
<i class="trash outline icon"></i>
Gemeinde löschen
</div>
<div class="content">
<p>Willst du diese Gemeinde wirklich löschen? Diese Aktion kann nicht Rückgängig gemacht werden!</p>
</div>
<div class="actions">
<div class="ui red basic cancel inverted button">
<i class="remove icon"></i>
Nein
</div>
<div class="ui green ok inverted button">
<i class="checkmark icon"></i>
Ja!
</div>
</div>
</div>
<!-- Gemeinde hinzufügen -->
<div class="ui gemeindeadd modal">
<i class="close icon"></i>
<div class="header">
Gemeinde hinzufügen
</div>
<div class="image content">
<div class="ui input">
<input type="text" id="name" placeholder="Name"/>
</div>
</div>
<div class="actions">
<div class="ui black deny button">
Abbrechen
</div>
<div class="ui positive right labeled icon button">
Hinzufügen
<i class="checkmark icon"></i>
</div>
</div>
</div>
<script src="/assets/js/admin.js"></script>
{{end}}
</body>
</html>
{{end}}

View File

@ -10,7 +10,9 @@
<table width="100%" border="0" cellpadding="0" cellspacing="0"> <table width="100%" border="0" cellpadding="0" cellspacing="0">
<tr class="top"> <tr class="top">
<th scope="col">Name</th> <th scope="col">Name</th>
{{if eq .Mode 0}}
<th scope="col">Gemeinde</th> <th scope="col">Gemeinde</th>
{{end}}
<th scope="col">Eingezahlte KonfiCoins</th> <th scope="col">Eingezahlte KonfiCoins</th>
</tr> </tr>
<tbody id="konfis"> <tbody id="konfis">

37
tpl/login.html Normal file
View File

@ -0,0 +1,37 @@
{{define "login"}}
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Kasino Admin</title>
<link rel="stylesheet" type="text/css" href="/assets/semantic/semantic.min.css">
<script src="/assets/js/jquery-3.1.1.min.js"></script>
<script src="/assets/semantic/semantic.min.js"></script>
</head>
<body style="background: url(/assets/bg.jpg) no-repeat center fixed">
<div class="ui middle aligned center aligned grid" style="width: 30em; margin: 37vh auto;">
<div class="column">
<h2 class="ui header">
Kasino Admin
</h2>
<form class="ui large form" id="loginform" method="post">
<div class="ui segment">
<div class="field">
<div class="ui left icon input">
<i class="lock icon"></i>
<input type="password" name="password" id="password" placeholder="Passwort" autofocus>
</div>
</div>
<div class="ui fluid large blue submit button">Login</div>
</div>
<div id="msg"></div>
</form>
</div>
</div>
<script src="/assets/js/login.js"></script>
</body>
</html>
{{end}}

View File

@ -7,6 +7,10 @@ import (
) )
func update(c echo.Context) error { func update(c echo.Context) error {
//Config
SiteConf := initConfig()
//Datenbankverbindung aufbauen //Datenbankverbindung aufbauen
db := DBinit() db := DBinit()
@ -22,23 +26,43 @@ func update(c echo.Context) error {
id, _ := strconv.Atoi(c.FormValue("id")) id, _ := strconv.Atoi(c.FormValue("id"))
addcoins, _ := strconv.Atoi(c.FormValue("addcoins")) addcoins, _ := strconv.Atoi(c.FormValue("addcoins"))
//Aktuelle Coins holen if SiteConf.Mode == 0 {
var kofi = Kofi{ID: id} //Aktuelle Coins holen
has, err := db.Get(&kofi) var kofi= Kofi{ID: id}
checkErr(err) has, err := db.Get(&kofi)
if has { checkErr(err)
newCoins := kofi.KCoins + addcoins if has {
newCoins := kofi.KCoins + addcoins
//Updaten //Updaten
kofi.KCoins = newCoins kofi.KCoins = newCoins
_, err := db.Id(id).Update(kofi) _, err := db.Id(id).Update(kofi)
if err == nil { if err == nil {
return c.JSON(http.StatusOK, UpdatedMessage{"success", kofi}) return c.JSON(http.StatusOK, UpdatedMessageKofi{"success", kofi})
}
} }
} else if SiteConf.Mode == 1{
var gemeinde = Gemeinde{ID: id}
has, err := db.Get(&gemeinde)
checkErr(err)
if has {
newCoins := gemeinde.KCoins + addcoins
// Updaten
gemeinde.KCoins = newCoins
_, err := db.ID(id).Update(gemeinde)
if err == nil {
return c.JSON(http.StatusOK, UpdatedMessageGemeinde{"success", gemeinde})
}
}
} }
return c.JSON(http.StatusOK, Message{"Error."}) return c.JSON(http.StatusInternalServerError, Message{"Error."})
} else { } else {
return c.JSON(http.StatusOK, Message{"Login first."}) return c.JSON(http.StatusForbidden, Message{"Login first."})
} }
} }

View File

@ -9,13 +9,24 @@ type Kofi struct {
KCoins int KCoins int
} }
type Gemeinde struct {
ID int `xorm:"pk autoincr"`
Name string
KCoins int
}
type Message struct { type Message struct {
Message string Message string
} }
type UpdatedMessage struct { type UpdatedMessageKofi struct {
Message string Message string
Kofi Kofi Data Kofi
}
type UpdatedMessageGemeinde struct {
Message string
Data Gemeinde
} }
//CheckError //CheckError

View File

@ -1,6 +1,8 @@
# A pure Go MSSQL driver for Go's database/sql package # A pure Go MSSQL driver for Go's database/sql package
[![GoDoc](https://godoc.org/github.com/denisenkom/go-mssqldb?status.svg)](http://godoc.org/github.com/denisenkom/go-mssqldb)
[![Build status](https://ci.appveyor.com/api/projects/status/ujv21jd241h8o5s7?svg=true)](https://ci.appveyor.com/project/denisenk/go-mssqldb) [![Build status](https://ci.appveyor.com/api/projects/status/ujv21jd241h8o5s7?svg=true)](https://ci.appveyor.com/project/denisenk/go-mssqldb)
[![codecov](https://codecov.io/gh/denisenkom/go-mssqldb/branch/master/graph/badge.svg)](https://codecov.io/gh/denisenkom/go-mssqldb)
## Install ## Install

View File

@ -44,7 +44,9 @@ before_test:
Start-Service "SQLBrowser" Start-Service "SQLBrowser"
- sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;" - sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;"
- sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version" - sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version"
- pip install codecov
test_script: test_script:
- go test -v -cover - go test -race -coverprofile=coverage.txt -covermode=atomic
- codecov -f coverage.txt

View File

@ -1,11 +1,9 @@
package mssql package mssql
import ( import (
"database/sql/driver"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"net"
) )
type packetType uint8 type packetType uint8
@ -53,19 +51,6 @@ func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
return w return w
} }
func checkBadConn(err error) error {
if err == io.EOF {
return driver.ErrBadConn
}
switch err.(type) {
case net.Error:
return driver.ErrBadConn
default:
return err
}
}
func (rw *tdsBuffer) ResizeBuffer(packetsizei int) { func (rw *tdsBuffer) ResizeBuffer(packetsizei int) {
if len(rw.rbuf) != packetsizei { if len(rw.rbuf) != packetsizei {
newbuf := make([]byte, packetsizei) newbuf := make([]byte, packetsizei)
@ -152,7 +137,7 @@ func (r *tdsBuffer) readNextPacket() error {
var err error var err error
err = binary.Read(r.transport, binary.BigEndian, &header) err = binary.Read(r.transport, binary.BigEndian, &header)
if err != nil { if err != nil {
return checkBadConn(err) return err
} }
offset := uint16(binary.Size(header)) offset := uint16(binary.Size(header))
if int(header.Size) > len(r.rbuf) { if int(header.Size) > len(r.rbuf) {
@ -163,7 +148,7 @@ func (r *tdsBuffer) readNextPacket() error {
} }
_, err = io.ReadFull(r.transport, r.rbuf[offset:header.Size]) _, err = io.ReadFull(r.transport, r.rbuf[offset:header.Size])
if err != nil { if err != nil {
return checkBadConn(err) return err
} }
r.rpos = offset r.rpos = offset
r.rsize = header.Size r.rsize = header.Size
@ -206,7 +191,7 @@ func (r *tdsBuffer) byte() byte {
func (r *tdsBuffer) ReadFull(buf []byte) { func (r *tdsBuffer) ReadFull(buf []byte) {
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { if err != nil {
badStreamPanic(checkBadConn(err)) badStreamPanic(err)
} }
} }

View File

@ -71,11 +71,18 @@ func TestBulkcopy(t *testing.T) {
conn := open(t) conn := open(t)
defer conn.Close() defer conn.Close()
setupTable(conn, tableName) err := setupTable(conn, tableName)
if (err != nil) {
t.Error("Setup table failed: ", err.Error())
return
}
log.Println("Preparing copyin statement")
stmt, err := conn.Prepare(CopyIn(tableName, MssqlBulkOptions{}, columns...)) stmt, err := conn.Prepare(CopyIn(tableName, MssqlBulkOptions{}, columns...))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
log.Printf("Executing copy in statement %d time with %d values", i+1, len(values))
_, err = stmt.Exec(values...) _, err = stmt.Exec(values...)
if err != nil { if err != nil {
t.Error("AddRow failed: ", err.Error()) t.Error("AddRow failed: ", err.Error())
@ -142,8 +149,8 @@ func compareValue(a interface{}, expected interface{}) bool {
return reflect.DeepEqual(expected, a) return reflect.DeepEqual(expected, a)
} }
} }
func setupTable(conn *sql.DB, tableName string) {
func setupTable(conn *sql.DB, tableName string) (err error) {
tablesql := `CREATE TABLE ` + tableName + ` ( tablesql := `CREATE TABLE ` + tableName + ` (
[id] [int] IDENTITY(1,1) NOT NULL, [id] [int] IDENTITY(1,1) NOT NULL,
[test_nvarchar] [nvarchar](50) NULL, [test_nvarchar] [nvarchar](50) NULL,
@ -186,9 +193,9 @@ func setupTable(conn *sql.DB, tableName string) {
[id] ASC [id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY];` ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY];`
_, err := conn.Exec(tablesql) _, err = conn.Exec(tablesql)
if err != nil { if err != nil {
log.Fatal("tablesql failed:", err) log.Fatal("tablesql failed:", err)
} }
return
} }

View File

@ -58,6 +58,38 @@ type MssqlConn struct {
transactionCtx context.Context transactionCtx context.Context
processQueryText bool processQueryText bool
connectionGood bool
}
func (c *MssqlConn) checkBadConn(err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
// error indicates that connection is not usable
// but we return actual error instead of ErrBadConn
// this will cause connection to stay in a pool
// but next request to this connection will return ErrBadConn
// it might be possible to revise this hack after
// https://github.com/golang/go/issues/20807
// is implemented
if err == nil {
return nil
}
if err == io.EOF {
return driver.ErrBadConn
}
switch err.(type) {
case net.Error:
c.connectionGood = false
return err
case StreamError:
c.connectionGood = false
return err
default:
return err
}
} }
func (c *MssqlConn) simpleProcessResp(ctx context.Context) error { func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
@ -67,18 +99,21 @@ func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
switch token := tok.(type) { switch token := tok.(type) {
case doneStruct: case doneStruct:
if token.isError() { if token.isError() {
return token.getError() return c.checkBadConn(token.getError())
} }
case error: case error:
return token return c.checkBadConn(token)
} }
} }
return nil return nil
} }
func (c *MssqlConn) Commit() error { func (c *MssqlConn) Commit() error {
if !c.connectionGood {
return driver.ErrBadConn
}
if err := c.sendCommitRequest(); err != nil { if err := c.sendCommitRequest(); err != nil {
return err return c.checkBadConn(err)
} }
return c.simpleProcessResp(c.transactionCtx) return c.simpleProcessResp(c.transactionCtx)
} }
@ -98,8 +133,11 @@ func (c *MssqlConn) sendCommitRequest() error {
} }
func (c *MssqlConn) Rollback() error { func (c *MssqlConn) Rollback() error {
if !c.connectionGood {
return driver.ErrBadConn
}
if err := c.sendRollbackRequest(); err != nil { if err := c.sendRollbackRequest(); err != nil {
return err return c.checkBadConn(err)
} }
return c.simpleProcessResp(c.transactionCtx) return c.simpleProcessResp(c.transactionCtx)
} }
@ -122,12 +160,19 @@ func (c *MssqlConn) Begin() (driver.Tx, error) {
return c.begin(context.Background(), isolationUseCurrent) return c.begin(context.Background(), isolationUseCurrent)
} }
func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) { func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
err := c.sendBeginRequest(ctx, tdsIsolation) if !c.connectionGood {
if err != nil { return nil, driver.ErrBadConn
return nil, err
} }
return c.processBeginResponse(ctx) err = c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, c.checkBadConn(err)
}
tx, err = c.processBeginResponse(ctx)
if err != nil {
return nil, c.checkBadConn(err)
}
return
} }
func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error { func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
@ -183,7 +228,7 @@ func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
} }
} }
conn := &MssqlConn{sess, context.Background(), d.processQueryText} conn := &MssqlConn{sess: sess, transactionCtx: context.Background(), processQueryText: d.processQueryText, connectionGood: true}
conn.sess.log = d.log conn.sess.log = d.log
return conn, nil return conn, nil
} }
@ -206,6 +251,9 @@ type queryNotifSub struct {
} }
func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) { func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") { if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
return c.prepareCopyIn(query) return c.prepareCopyIn(query)
} }
@ -315,9 +363,12 @@ func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.queryContext(context.Background(), convertOldArgs(args)) return s.queryContext(context.Background(), convertOldArgs(args))
} }
func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) { func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
if err := s.sendQuery(args); err != nil { if !s.c.connectionGood {
return nil, err return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
} }
return s.processQueryResponse(ctx) return s.processQueryResponse(ctx)
} }
@ -343,13 +394,13 @@ loop:
break loop break loop
case doneStruct: case doneStruct:
if token.isError() { if token.isError() {
return nil, token.getError() return nil, s.c.checkBadConn(token.getError())
} }
case error: case error:
return nil, token return nil, s.c.checkBadConn(token)
} }
} }
res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel} res = &MssqlRows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
return return
} }
@ -357,11 +408,17 @@ func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), convertOldArgs(args)) return s.exec(context.Background(), convertOldArgs(args))
} }
func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) { func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
if err := s.sendQuery(args); err != nil { if !s.c.connectionGood {
return nil, err return nil, driver.ErrBadConn
} }
return s.processExec(ctx) if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
}
if res, err = s.processExec(ctx); err != nil {
return nil, s.c.checkBadConn(err)
}
return
} }
func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) { func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
@ -389,7 +446,7 @@ func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err err
} }
type MssqlRows struct { type MssqlRows struct {
sess *tdsSession stmt *MssqlStmt
cols []columnStruct cols []columnStruct
tokchan chan tokenStruct tokchan chan tokenStruct
@ -415,6 +472,9 @@ func (rc *MssqlRows) Columns() (res []string) {
} }
func (rc *MssqlRows) Next(dest []driver.Value) error { func (rc *MssqlRows) Next(dest []driver.Value) error {
if !rc.stmt.c.connectionGood {
return driver.ErrBadConn
}
if rc.nextCols != nil { if rc.nextCols != nil {
return io.EOF return io.EOF
} }
@ -430,10 +490,10 @@ func (rc *MssqlRows) Next(dest []driver.Value) error {
return nil return nil
case doneStruct: case doneStruct:
if tokdata.isError() { if tokdata.isError() {
return tokdata.getError() return rc.stmt.c.checkBadConn(tokdata.getError())
} }
case error: case error:
return tokdata return rc.stmt.c.checkBadConn(tokdata)
} }
} }
return io.EOF return io.EOF

View File

@ -343,7 +343,6 @@ func TestExec(t *testing.T) {
} }
func TestShortTimeout(t *testing.T) { func TestShortTimeout(t *testing.T) {
t.Skip("TODO: fix this test")
if testing.Short() { if testing.Short() {
t.Skip("short") t.Skip("short")
} }
@ -841,7 +840,6 @@ func TestConnectionClosing(t *testing.T) {
} }
func TestBeginTranError(t *testing.T) { func TestBeginTranError(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t) checkConnStr(t)
drv := driverWithProcess(t) drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String()) conn, err := drv.open(makeConnStr(t).String())
@ -878,10 +876,13 @@ func TestBeginTranError(t *testing.T) {
case driver.ErrBadConn: case driver.ErrBadConn:
t.Error("processBeginResponse should fail with error different from ErrBadConn but it did") t.Error("processBeginResponse should fail with error different from ErrBadConn but it did")
} }
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
} }
func TestCommitTranError(t *testing.T) { func TestCommitTranError(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t) checkConnStr(t)
drv := driverWithProcess(t) drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String()) conn, err := drv.open(makeConnStr(t).String())
@ -919,6 +920,10 @@ func TestCommitTranError(t *testing.T) {
t.Error("simpleProcessResp should fail with error different from ErrBadConn but it did") t.Error("simpleProcessResp should fail with error different from ErrBadConn but it did")
} }
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
// reopen connection // reopen connection
conn, err = drv.open(makeConnStr(t).String()) conn, err = drv.open(makeConnStr(t).String())
defer conn.Close() defer conn.Close()
@ -936,7 +941,6 @@ func TestCommitTranError(t *testing.T) {
} }
func TestRollbackTranError(t *testing.T) { func TestRollbackTranError(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t) checkConnStr(t)
drv := driverWithProcess(t) drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String()) conn, err := drv.open(makeConnStr(t).String())
@ -974,6 +978,10 @@ func TestRollbackTranError(t *testing.T) {
t.Error("simpleProcessResp should fail with error different from ErrBadConn but it did") t.Error("simpleProcessResp should fail with error different from ErrBadConn but it did")
} }
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
// reopen connection // reopen connection
conn, err = drv.open(makeConnStr(t).String()) conn, err = drv.open(makeConnStr(t).String())
defer conn.Close() defer conn.Close()
@ -1031,7 +1039,6 @@ func TestSendQueryErrors(t *testing.T) {
} }
func TestProcessQueryErrors(t *testing.T) { func TestProcessQueryErrors(t *testing.T) {
t.Skip("TODO: fix this test")
checkConnStr(t) checkConnStr(t)
drv := driverWithProcess(t) drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr(t).String()) conn, err := drv.open(makeConnStr(t).String())
@ -1056,6 +1063,10 @@ func TestProcessQueryErrors(t *testing.T) {
if err == driver.ErrBadConn { if err == driver.ErrBadConn {
t.Error("processQueryResponse expected to fail with error other than ErrBadConn but it failed with it") t.Error("processQueryResponse expected to fail with error other than ErrBadConn but it failed with it")
} }
if conn.connectionGood {
t.Fatal("Connection should be in a bad state")
}
} }
func TestSendExecErrors(t *testing.T) { func TestSendExecErrors(t *testing.T) {

View File

@ -198,36 +198,6 @@ func TestConnect(t *testing.T) {
defer conn.Close() defer conn.Close()
} }
func TestBadConnect(t *testing.T) {
t.Skip("TODO: fix this test")
var badDSNs []string
if parsed, err := url.Parse(os.Getenv("SQLSERVER_DSN")); err == nil {
parsed.User = url.UserPassword("baduser", "badpwd")
badDSNs = append(badDSNs, parsed.String())
}
if len(os.Getenv("HOST")) > 0 && len(os.Getenv("INSTANCE")) > 0 {
badDSNs = append(badDSNs,
fmt.Sprintf(
"Server=%s\\%s;User ID=baduser;Password=badpwd",
os.Getenv("HOST"), os.Getenv("INSTANCE"),
),
)
}
SetLogger(testLogger{t})
for _, badDsn := range badDSNs {
conn, err := sql.Open("mssql", badDsn)
if err != nil {
t.Error("Open connection failed:", err.Error())
}
defer conn.Close()
err = conn.Ping()
if err == nil {
t.Error("Ping should fail for connection: ", badDsn)
}
}
}
func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) { func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) {
stmt, err := conn.Prepare("select 1 as a") stmt, err := conn.Prepare("select 1 as a")
if err != nil { if err != nil {

View File

@ -923,6 +923,10 @@ func makeGoLangScanType(ti typeInfo) reflect.Type {
default: default:
panic("invalid size of MONEYN") panic("invalid size of MONEYN")
} }
case typeDateTim4:
return reflect.TypeOf(time.Time{})
case typeDateTime:
return reflect.TypeOf(time.Time{})
case typeDateTimeN: case typeDateTimeN:
switch ti.Size { switch ti.Size {
case 4: case 4:

View File

@ -1,11 +1,10 @@
sudo: false sudo: false
language: go language: go
go: go:
- 1.4 - 1.5.x
- 1.5 - 1.6.x
- 1.6 - 1.7.x
- 1.7 - 1.8.x
- 1.8
- master - master
script: script:

63
vendor/github.com/go-ini/ini/ini.go generated vendored
View File

@ -24,10 +24,8 @@ import (
"os" "os"
"regexp" "regexp"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"time"
) )
const ( const (
@ -37,7 +35,7 @@ const (
// Maximum allowed depth when recursively substituing variable names. // Maximum allowed depth when recursively substituing variable names.
_DEPTH_VALUES = 99 _DEPTH_VALUES = 99
_VERSION = "1.28.1" _VERSION = "1.28.2"
) )
// Version returns current package version literal. // Version returns current package version literal.
@ -398,10 +396,7 @@ func (f *File) Append(source interface{}, others ...interface{}) error {
return f.Reload() return f.Reload()
} }
// WriteToIndent writes content into io.Writer with given indention. func (f *File) writeToBuffer(indent string) (*bytes.Buffer, error) {
// If PrettyFormat has been set to be true,
// it will align "=" sign with spaces under each section.
func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
equalSign := "=" equalSign := "="
if PrettyFormat { if PrettyFormat {
equalSign = " = " equalSign = " = "
@ -415,14 +410,14 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
if sec.Comment[0] != '#' && sec.Comment[0] != ';' { if sec.Comment[0] != '#' && sec.Comment[0] != ';' {
sec.Comment = "; " + sec.Comment sec.Comment = "; " + sec.Comment
} }
if _, err = buf.WriteString(sec.Comment + LineBreak); err != nil { if _, err := buf.WriteString(sec.Comment + LineBreak); err != nil {
return 0, err return nil, err
} }
} }
if i > 0 || DefaultHeader { if i > 0 || DefaultHeader {
if _, err = buf.WriteString("[" + sname + "]" + LineBreak); err != nil { if _, err := buf.WriteString("[" + sname + "]" + LineBreak); err != nil {
return 0, err return nil, err
} }
} else { } else {
// Write nothing if default section is empty // Write nothing if default section is empty
@ -432,8 +427,8 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
} }
if sec.isRawSection { if sec.isRawSection {
if _, err = buf.WriteString(sec.rawBody); err != nil { if _, err := buf.WriteString(sec.rawBody); err != nil {
return 0, err return nil, err
} }
continue continue
} }
@ -469,8 +464,8 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
if key.Comment[0] != '#' && key.Comment[0] != ';' { if key.Comment[0] != '#' && key.Comment[0] != ';' {
key.Comment = "; " + key.Comment key.Comment = "; " + key.Comment
} }
if _, err = buf.WriteString(key.Comment + LineBreak); err != nil { if _, err := buf.WriteString(key.Comment + LineBreak); err != nil {
return 0, err return nil, err
} }
} }
@ -488,8 +483,8 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
} }
for _, val := range key.ValueWithShadows() { for _, val := range key.ValueWithShadows() {
if _, err = buf.WriteString(kname); err != nil { if _, err := buf.WriteString(kname); err != nil {
return 0, err return nil, err
} }
if key.isBooleanType { if key.isBooleanType {
@ -510,20 +505,31 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
} else if !f.options.IgnoreInlineComment && strings.ContainsAny(val, "#;") { } else if !f.options.IgnoreInlineComment && strings.ContainsAny(val, "#;") {
val = "`" + val + "`" val = "`" + val + "`"
} }
if _, err = buf.WriteString(equalSign + val + LineBreak); err != nil { if _, err := buf.WriteString(equalSign + val + LineBreak); err != nil {
return 0, err return nil, err
} }
} }
} }
if PrettySection { if PrettySection {
// Put a line between sections // Put a line between sections
if _, err = buf.WriteString(LineBreak); err != nil { if _, err := buf.WriteString(LineBreak); err != nil {
return 0, err return nil, err
} }
} }
} }
return buf, nil
}
// WriteToIndent writes content into io.Writer with given indention.
// If PrettyFormat has been set to be true,
// it will align "=" sign with spaces under each section.
func (f *File) WriteToIndent(w io.Writer, indent string) (int64, error) {
buf, err := f.writeToBuffer(indent)
if err != nil {
return 0, err
}
return buf.WriteTo(w) return buf.WriteTo(w)
} }
@ -536,23 +542,12 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
func (f *File) SaveToIndent(filename, indent string) error { func (f *File) SaveToIndent(filename, indent string) error {
// Note: Because we are truncating with os.Create, // Note: Because we are truncating with os.Create,
// so it's safer to save to a temporary file location and rename afte done. // so it's safer to save to a temporary file location and rename afte done.
tmpPath := filename + "." + strconv.Itoa(time.Now().Nanosecond()) + ".tmp" buf, err := f.writeToBuffer(indent)
defer os.Remove(tmpPath)
fw, err := os.Create(tmpPath)
if err != nil { if err != nil {
return err return err
} }
if _, err = f.WriteToIndent(fw, indent); err != nil { return ioutil.WriteFile(filename, buf.Bytes(), 0666)
fw.Close()
return err
}
fw.Close()
// Remove old file and rename the new one.
os.Remove(filename)
return os.Rename(tmpPath, filename)
} }
// SaveTo writes content to file system. // SaveTo writes content to file system.

View File

@ -13,12 +13,13 @@ const (
ONLYFROMDB ONLYFROMDB
) )
// database column // Column defines database column
type Column struct { type Column struct {
Name string Name string
TableName string TableName string
FieldName string FieldName string
SQLType SQLType SQLType SQLType
IsJSON bool
Length int Length int
Length2 int Length2 int
Nullable bool Nullable bool

View File

@ -247,6 +247,18 @@ type Row struct {
err error // deferred error for easy chaining err error // deferred error for easy chaining
} }
// ErrorRow return an error row
func ErrorRow(err error) *Row {
return &Row{
err: err,
}
}
// NewRow from rows
func NewRow(rows *Rows, err error) *Row {
return &Row{rows, err}
}
func (row *Row) Columns() ([]string, error) { func (row *Row) Columns() ([]string, error) {
if row.err != nil { if row.err != nil {
return nil, row.err return nil, row.err

View File

@ -6,10 +6,6 @@ Xorm is a simple and powerful ORM for Go.
[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm) [![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm)
[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) [![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3)
# Notice
The last master version is not backwards compatible. You should use `engine.ShowSQL()` and `engine.Logger().SetLevel()` instead of `engine.ShowSQL = `, `engine.ShowInfo = ` and so on.
# Features # Features
* Struct <-> Table Mapping Support * Struct <-> Table Mapping Support
@ -38,7 +34,7 @@ Drivers for Go's sql package which currently support database/sql includes:
* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) * Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/tree/master/godrv)
* Postgres: [github.com/lib/pq](https://github.com/lib/pq) * Postgres: [github.com/lib/pq](https://github.com/lib/pq)
@ -52,6 +48,14 @@ Drivers for Go's sql package which currently support database/sql includes:
# Changelog # Changelog
* **v0.6.3**
* merge tests to main project
* add `Exist` function
* add `SumInt` function
* Mysql now support read and create column comment.
* fix time related bugs.
* fix some other bugs.
* **v0.6.2** * **v0.6.2**
* refactor tag parse methods * refactor tag parse methods
* add Scan features to Get * add Scan features to Get
@ -64,22 +68,6 @@ methods can use `builder.Cond` as parameter
* add Sum, SumInt, SumInt64 and NotIn methods * add Sum, SumInt, SumInt64 and NotIn methods
* some bugs fixed * some bugs fixed
* **v0.5.0**
* logging interface changed
* some bugs fixed
* **v0.4.5**
* many bugs fixed
* extends support unlimited deepth
* Delete Limit support
* **v0.4.4**
* ql database expriment support
* tidb database expriment support
* sql.NullString and etc. field support
* select ForUpdate support
* many bugs fixed
[More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16) [More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16)
# Installation # Installation
@ -126,7 +114,7 @@ results, err := engine.Query("select * from user")
results, err := engine.QueryString("select * from user") results, err := engine.QueryString("select * from user")
``` ```
* `Execute` runs a SQL string, it returns `affetcted` and `error` * `Execute` runs a SQL string, it returns `affected` and `error`
```Go ```Go
affected, err := engine.Exec("update user set age = ? where name = ?", age, name) affected, err := engine.Exec("update user set age = ? where name = ?", age, name)
@ -168,6 +156,25 @@ has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ? // SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
* Check if one record exist on table
```Go
has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{
Name: "test1",
})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
```
* Query multiple records from database, also you can use join and extends * Query multiple records from database, also you can use join and extends
```Go ```Go
@ -260,6 +267,14 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
# Cases # Cases
* [studygolang](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
* [Gitea](http://gitea.io) - [github.com/go-gitea/gitea](http://github.com/go-gitea/gitea)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [grafana](https://grafana.com/) - [github.com/grafana/grafana](http://github.com/grafana/grafana)
* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) * [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader)
* [Wego](http://github.com/go-tango/wego) * [Wego](http://github.com/go-tango/wego)
@ -268,8 +283,6 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter) * [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [Gorevel](http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel) * [Gorevel](http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel)
* [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker) * [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker)

View File

@ -8,10 +8,6 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm) [![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm)
[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) [![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3)
# 注意
最新的版本有不兼容的更新,您必须使用 `engine.ShowSQL()``engine.Logger().SetLevel()` 来替代 `engine.ShowSQL = `, `engine.ShowInfo = ` 等等。
## 特性 ## 特性
* 支持Struct和数据库表之间的灵活映射并支持自动同步 * 支持Struct和数据库表之间的灵活映射并支持自动同步
@ -56,6 +52,15 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
## 更新日志 ## 更新日志
* **v0.6.3**
* 合并单元测试到主工程
* 新增`Exist`方法
* 新增`SumInt`方法
* Mysql新增读取和创建字段注释支持
* 新增`SetConnMaxLifetime`方法
* 修正了时间相关的Bug
* 修复了一些其它Bug
* **v0.6.2** * **v0.6.2**
* 重构Tag解析方式 * 重构Tag解析方式
* Get方法新增类似Scan的特性 * Get方法新增类似Scan的特性
@ -72,18 +77,6 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* logging接口进行不兼容改变 * logging接口进行不兼容改变
* Bug修正 * Bug修正
* **v0.4.5**
* bug修正
* extends 支持无限级
* Delete Limit 支持
* **v0.4.4**
* Tidb 数据库支持
* QL 试验性支持
* sql.NullString支持
* ForUpdate 支持
* bug修正
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16) [更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
## 安装 ## 安装
@ -172,6 +165,25 @@ has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ? // SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
* 检测记录是否存在
```Go
has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{
Name: "test1",
})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
```
* 查询多条记录当然可以使用Join和extends来组合使用 * 查询多条记录当然可以使用Join和extends来组合使用
```Go ```Go
@ -263,6 +275,14 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
# 案例 # 案例
* [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
* [Gitea](http://gitea.io) - [github.com/go-gitea/gitea](http://github.com/go-gitea/gitea)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [grafana](https://grafana.com/) - [github.com/grafana/grafana](http://github.com/grafana/grafana)
* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) * [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader)
* [Wego](http://github.com/go-tango/wego) * [Wego](http://github.com/go-tango/wego)
@ -271,8 +291,6 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter) * [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter)
* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs)
* [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker) * [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker)
* [Gobuild.io](http://gobuild.io) - [github.com/shxsun/gobuild](http://github.com/shxsun/gobuild) * [Gobuild.io](http://gobuild.io) - [github.com/shxsun/gobuild](http://github.com/shxsun/gobuild)

View File

@ -21,7 +21,7 @@ database:
test: test:
override: override:
# './...' is a relative pattern which means all subdirectories # './...' is a relative pattern which means all subdirectories
- go test -v -race -db="sqlite3::mysql::postgres" -conn_str="./test.db::root:@/xorm_test::dbname=xorm_test sslmode=disable" -coverprofile=coverage.txt -covermode=atomic - go test -v -race -db="sqlite3::mysql::mymysql::postgres" -conn_str="./test.db::root:@/xorm_test::xorm_test/root/::dbname=xorm_test sslmode=disable" -coverprofile=coverage.txt -covermode=atomic
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh

View File

@ -90,7 +90,7 @@ another is Rows
5. Update one or more records 5. Update one or more records
affected, err := engine.Id(...).Update(&user) affected, err := engine.ID(...).Update(&user)
// UPDATE user SET ... // UPDATE user SET ...
6. Delete one or more records, Delete MUST has condition 6. Delete one or more records, Delete MUST has condition

View File

@ -169,7 +169,7 @@ func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
} }
// SqlType will be depracated, please use SQLType instead // SqlType will be deprecated, please use SQLType instead
// //
// Deprecated: use SQLType instead // Deprecated: use SQLType instead
func (engine *Engine) SqlType(c *core.Column) string { func (engine *Engine) SqlType(c *core.Column) string {
@ -205,14 +205,14 @@ func (engine *Engine) SetDefaultCacher(cacher core.Cacher) {
// you can use NoCache() // you can use NoCache()
func (engine *Engine) NoCache() *Session { func (engine *Engine) NoCache() *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.NoCache() return session.NoCache()
} }
// NoCascade If you do not want to auto cascade load object // NoCascade If you do not want to auto cascade load object
func (engine *Engine) NoCascade() *Session { func (engine *Engine) NoCascade() *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.NoCascade() return session.NoCascade()
} }
@ -245,7 +245,7 @@ func (engine *Engine) Dialect() core.Dialect {
// NewSession New a session // NewSession New a session
func (engine *Engine) NewSession() *Session { func (engine *Engine) NewSession() *Session {
session := &Session{Engine: engine} session := &Session{engine: engine}
session.Init() session.Init()
return session return session
} }
@ -259,7 +259,6 @@ func (engine *Engine) Close() error {
func (engine *Engine) Ping() error { func (engine *Engine) Ping() error {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
engine.logger.Infof("PING DATABASE %v", engine.DriverName())
return session.Ping() return session.Ping()
} }
@ -274,36 +273,6 @@ func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) {
} }
} }
func (engine *Engine) logSQLQueryTime(sqlStr string, args []interface{}, executionBlock func() (*core.Stmt, *core.Rows, error)) (*core.Stmt, *core.Rows, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
stmt, res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[SQL] %s %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
return stmt, res, err
}
return executionBlock()
}
func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, executionBlock func() (sql.Result, error)) (sql.Result, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[sql] %s [args] %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[sql] %s - took: %v", sqlStr, execDuration)
}
return res, err
}
return executionBlock()
}
// Sql provides raw sql input parameter. When you have a complex SQL statement // Sql provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.
// //
@ -320,7 +289,7 @@ func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
// This code will execute "select * from user" and set the records to users // This code will execute "select * from user" and set the records to users
func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session { func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.SQL(query, args...) return session.SQL(query, args...)
} }
@ -329,14 +298,14 @@ func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
// invoked. Call NoAutoTime if you dont' want to fill automatically. // invoked. Call NoAutoTime if you dont' want to fill automatically.
func (engine *Engine) NoAutoTime() *Session { func (engine *Engine) NoAutoTime() *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.NoAutoTime() return session.NoAutoTime()
} }
// NoAutoCondition disable auto generate Where condition from bean or not // NoAutoCondition disable auto generate Where condition from bean or not
func (engine *Engine) NoAutoCondition(no ...bool) *Session { func (engine *Engine) NoAutoCondition(no ...bool) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.NoAutoCondition(no...) return session.NoAutoCondition(no...)
} }
@ -570,56 +539,56 @@ func (engine *Engine) tbName(v reflect.Value) string {
// Cascade use cascade or not // Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Cascade(trueOrFalse...) return session.Cascade(trueOrFalse...)
} }
// Where method provide a condition query // Where method provide a condition query
func (engine *Engine) Where(query interface{}, args ...interface{}) *Session { func (engine *Engine) Where(query interface{}, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Where(query, args...) return session.Where(query, args...)
} }
// Id will be depracated, please use ID instead // Id will be deprecated, please use ID instead
func (engine *Engine) Id(id interface{}) *Session { func (engine *Engine) Id(id interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Id(id) return session.Id(id)
} }
// ID method provoide a condition as (id) = ? // ID method provoide a condition as (id) = ?
func (engine *Engine) ID(id interface{}) *Session { func (engine *Engine) ID(id interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.ID(id) return session.ID(id)
} }
// Before apply before Processor, affected bean is passed to closure arg // Before apply before Processor, affected bean is passed to closure arg
func (engine *Engine) Before(closures func(interface{})) *Session { func (engine *Engine) Before(closures func(interface{})) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Before(closures) return session.Before(closures)
} }
// After apply after insert Processor, affected bean is passed to closure arg // After apply after insert Processor, affected bean is passed to closure arg
func (engine *Engine) After(closures func(interface{})) *Session { func (engine *Engine) After(closures func(interface{})) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.After(closures) return session.After(closures)
} }
// Charset set charset when create table, only support mysql now // Charset set charset when create table, only support mysql now
func (engine *Engine) Charset(charset string) *Session { func (engine *Engine) Charset(charset string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Charset(charset) return session.Charset(charset)
} }
// StoreEngine set store engine when create table, only support mysql now // StoreEngine set store engine when create table, only support mysql now
func (engine *Engine) StoreEngine(storeEngine string) *Session { func (engine *Engine) StoreEngine(storeEngine string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.StoreEngine(storeEngine) return session.StoreEngine(storeEngine)
} }
@ -628,35 +597,35 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session {
// but distinct will not provide id // but distinct will not provide id
func (engine *Engine) Distinct(columns ...string) *Session { func (engine *Engine) Distinct(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Distinct(columns...) return session.Distinct(columns...)
} }
// Select customerize your select columns or contents // Select customerize your select columns or contents
func (engine *Engine) Select(str string) *Session { func (engine *Engine) Select(str string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Select(str) return session.Select(str)
} }
// Cols only use the parameters as select or update columns // Cols only use the parameters as select or update columns
func (engine *Engine) Cols(columns ...string) *Session { func (engine *Engine) Cols(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Cols(columns...) return session.Cols(columns...)
} }
// AllCols indicates that all columns should be use // AllCols indicates that all columns should be use
func (engine *Engine) AllCols() *Session { func (engine *Engine) AllCols() *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.AllCols() return session.AllCols()
} }
// MustCols specify some columns must use even if they are empty // MustCols specify some columns must use even if they are empty
func (engine *Engine) MustCols(columns ...string) *Session { func (engine *Engine) MustCols(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.MustCols(columns...) return session.MustCols(columns...)
} }
@ -667,77 +636,84 @@ func (engine *Engine) MustCols(columns ...string) *Session {
// it will use parameters's columns // it will use parameters's columns
func (engine *Engine) UseBool(columns ...string) *Session { func (engine *Engine) UseBool(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.UseBool(columns...) return session.UseBool(columns...)
} }
// Omit only not use the parameters as select or update columns // Omit only not use the parameters as select or update columns
func (engine *Engine) Omit(columns ...string) *Session { func (engine *Engine) Omit(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Omit(columns...) return session.Omit(columns...)
} }
// Nullable set null when column is zero-value and nullable for update // Nullable set null when column is zero-value and nullable for update
func (engine *Engine) Nullable(columns ...string) *Session { func (engine *Engine) Nullable(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Nullable(columns...) return session.Nullable(columns...)
} }
// In will generate "column IN (?, ?)" // In will generate "column IN (?, ?)"
func (engine *Engine) In(column string, args ...interface{}) *Session { func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.In(column, args...) return session.In(column, args...)
} }
// NotIn will generate "column NOT IN (?, ?)"
func (engine *Engine) NotIn(column string, args ...interface{}) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.NotIn(column, args...)
}
// Incr provides a update string like "column = column + ?" // Incr provides a update string like "column = column + ?"
func (engine *Engine) Incr(column string, arg ...interface{}) *Session { func (engine *Engine) Incr(column string, arg ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Incr(column, arg...) return session.Incr(column, arg...)
} }
// Decr provides a update string like "column = column - ?" // Decr provides a update string like "column = column - ?"
func (engine *Engine) Decr(column string, arg ...interface{}) *Session { func (engine *Engine) Decr(column string, arg ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Decr(column, arg...) return session.Decr(column, arg...)
} }
// SetExpr provides a update string like "column = {expression}" // SetExpr provides a update string like "column = {expression}"
func (engine *Engine) SetExpr(column string, expression string) *Session { func (engine *Engine) SetExpr(column string, expression string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.SetExpr(column, expression) return session.SetExpr(column, expression)
} }
// Table temporarily change the Get, Find, Update's table // Table temporarily change the Get, Find, Update's table
func (engine *Engine) Table(tableNameOrBean interface{}) *Session { func (engine *Engine) Table(tableNameOrBean interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Table(tableNameOrBean) return session.Table(tableNameOrBean)
} }
// Alias set the table alias // Alias set the table alias
func (engine *Engine) Alias(alias string) *Session { func (engine *Engine) Alias(alias string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Alias(alias) return session.Alias(alias)
} }
// Limit will generate "LIMIT start, limit" // Limit will generate "LIMIT start, limit"
func (engine *Engine) Limit(limit int, start ...int) *Session { func (engine *Engine) Limit(limit int, start ...int) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Limit(limit, start...) return session.Limit(limit, start...)
} }
// Desc will generate "ORDER BY column1 DESC, column2 DESC" // Desc will generate "ORDER BY column1 DESC, column2 DESC"
func (engine *Engine) Desc(colNames ...string) *Session { func (engine *Engine) Desc(colNames ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Desc(colNames...) return session.Desc(colNames...)
} }
@ -749,35 +725,35 @@ func (engine *Engine) Desc(colNames ...string) *Session {
// //
func (engine *Engine) Asc(colNames ...string) *Session { func (engine *Engine) Asc(colNames ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Asc(colNames...) return session.Asc(colNames...)
} }
// OrderBy will generate "ORDER BY order" // OrderBy will generate "ORDER BY order"
func (engine *Engine) OrderBy(order string) *Session { func (engine *Engine) OrderBy(order string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.OrderBy(order) return session.OrderBy(order)
} }
// Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Join(joinOperator, tablename, condition, args...) return session.Join(joinOperator, tablename, condition, args...)
} }
// GroupBy generate group by statement // GroupBy generate group by statement
func (engine *Engine) GroupBy(keys string) *Session { func (engine *Engine) GroupBy(keys string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.GroupBy(keys) return session.GroupBy(keys)
} }
// Having generate having statement // Having generate having statement
func (engine *Engine) Having(conditions string) *Session { func (engine *Engine) Having(conditions string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Having(conditions) return session.Having(conditions)
} }
@ -1208,6 +1184,9 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
// table, column, index, unique. but will not delete or change anything. // table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually. // If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error { func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := rValue(bean)
tableName := engine.tbName(v) tableName := engine.tbName(v)
@ -1216,14 +1195,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
s := engine.NewSession() isExist, err := session.Table(bean).isTableExist(tableName)
defer s.Close()
isExist, err := s.Table(bean).isTableExist(tableName)
if err != nil { if err != nil {
return err return err
} }
if !isExist { if !isExist {
err = engine.CreateTables(bean) err = session.createTable(bean)
if err != nil { if err != nil {
return err return err
} }
@ -1234,11 +1211,11 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}*/ }*/
var isEmpty bool var isEmpty bool
if isEmpty { if isEmpty {
err = engine.DropTables(bean) err = session.dropTable(bean)
if err != nil { if err != nil {
return err return err
} }
err = engine.CreateTables(bean) err = session.createTable(bean)
if err != nil { if err != nil {
return err return err
} }
@ -1249,9 +1226,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
if !isExist { if !isExist {
session := engine.NewSession() if err := session.statement.setRefValue(v); err != nil {
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
return err return err
} }
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
@ -1262,9 +1237,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
for name, index := range table.Indexes { for name, index := range table.Indexes {
session := engine.NewSession() if err := session.statement.setRefValue(v); err != nil {
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
return err return err
} }
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
@ -1273,9 +1246,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
if !isExist { if !isExist {
session := engine.NewSession() if err := session.statement.setRefValue(v); err != nil {
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
return err return err
} }
@ -1290,9 +1261,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
if !isExist { if !isExist {
session := engine.NewSession() if err := session.statement.setRefValue(v); err != nil {
defer session.Close()
if err := session.Statement.setRefValue(v); err != nil {
return err return err
} }
@ -1328,7 +1297,7 @@ func (engine *Engine) CreateTables(beans ...interface{}) error {
} }
for _, bean := range beans { for _, bean := range beans {
err = session.CreateTable(bean) err = session.createTable(bean)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
return err return err
@ -1348,7 +1317,7 @@ func (engine *Engine) DropTables(beans ...interface{}) error {
} }
for _, bean := range beans { for _, bean := range beans {
err = session.DropTable(bean) err = session.dropTable(bean)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
return err return err
@ -1385,6 +1354,13 @@ func (engine *Engine) QueryString(sqlStr string, args ...interface{}) ([]map[str
return session.QueryString(sqlStr, args...) return session.QueryString(sqlStr, args...)
} }
// QueryInterface runs a raw sql and return records as []map[string]interface{}
func (engine *Engine) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) {
session := engine.NewSession()
defer session.Close()
return session.QueryInterface(sqlStr, args...)
}
// Insert one or more records // Insert one or more records
func (engine *Engine) Insert(beans ...interface{}) (int64, error) { func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
@ -1426,6 +1402,13 @@ func (engine *Engine) Get(bean interface{}) (bool, error) {
return session.Get(bean) return session.Get(bean)
} }
// Exist returns true if the record exist otherwise return false
func (engine *Engine) Exist(bean ...interface{}) (bool, error) {
session := engine.NewSession()
defer session.Close()
return session.Exist(bean...)
}
// Find retrieve records from table, condiBeans's non-empty fields // Find retrieve records from table, condiBeans's non-empty fields
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct // are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct // map[int64]*Struct
@ -1451,10 +1434,10 @@ func (engine *Engine) Rows(bean interface{}) (*Rows, error) {
} }
// Count counts the records. bean's non-empty fields are conditions. // Count counts the records. bean's non-empty fields are conditions.
func (engine *Engine) Count(bean interface{}) (int64, error) { func (engine *Engine) Count(bean ...interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.Count(bean) return session.Count(bean...)
} }
// Sum sum the records by some column. bean's non-empty fields are conditions. // Sum sum the records by some column. bean's non-empty fields are conditions.
@ -1580,7 +1563,7 @@ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}
// Unscoped always disable struct tag "deleted" // Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session { func (engine *Engine) Unscoped() *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.isAutoClose = true
return session.Unscoped() return session.Unscoped()
} }

View File

@ -67,7 +67,7 @@ func main() {
fmt.Println("users3:", users3) fmt.Println("users3:", users3)
user4 := new(User) user4 := new(User)
has, err := Orm.Id(1).Get(user4) has, err := Orm.ID(1).Get(user4)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
@ -76,7 +76,7 @@ func main() {
fmt.Println("user4:", has, user4) fmt.Println("user4:", has, user4)
user4.Name = "xiaolunwen" user4.Name = "xiaolunwen"
_, err = Orm.Id(1).Update(user4) _, err = Orm.ID(1).Update(user4)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
@ -84,14 +84,14 @@ func main() {
fmt.Println("user4:", user4) fmt.Println("user4:", user4)
user5 := new(User) user5 := new(User)
has, err = Orm.Id(1).Get(user5) has, err = Orm.ID(1).Get(user5)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user5:", has, user5) fmt.Println("user5:", has, user5)
_, err = Orm.Id(1).Delete(new(User)) _, err = Orm.ID(1).Delete(new(User))
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
@ -99,7 +99,7 @@ func main() {
for { for {
user6 := new(User) user6 := new(User)
has, err = Orm.Id(1).Get(user6) has, err = Orm.ID(1).Get(user6)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return

View File

@ -55,7 +55,7 @@ func test(engine *xorm.Engine) {
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
//_, err = engine.Id(1).Delete(u) //_, err = engine.ID(1).Delete(u)
_, err = engine.Delete(u) _, err = engine.Delete(u)
} }
if err != nil { if err != nil {

View File

@ -51,7 +51,7 @@ func main() {
} }
info := LoginInfo{} info := LoginInfo{}
_, err = orm.Id(1).Get(&info) _, err = orm.ID(1).Get(&info)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return

View File

@ -59,7 +59,7 @@ func test(engine *xorm.Engine) {
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
_, err = engine.Id(1).Delete(u) _, err = engine.ID(1).Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)

View File

@ -62,7 +62,7 @@ func test(engine *xorm.Engine) {
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
_, err = engine.Id(1).Delete(u) _, err = engine.ID(1).Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)

View File

@ -48,7 +48,7 @@ func main() {
} }
info := LoginInfo{} info := LoginInfo{}
_, err = orm.Id(1).Get(&info) _, err = orm.ID(1).Get(&info)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return

View File

@ -358,7 +358,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
for _, col := range table.Columns() { for _, col := range table.Columns() {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue continue
} }
} }
@ -397,32 +397,32 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
continue continue
} }
if session.Statement.ColumnStr != "" { if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue continue
} else if _, ok := session.Statement.incrColumns[col.Name]; ok { } else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue continue
} else if _, ok := session.Statement.decrColumns[col.Name]; ok { } else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue continue
} }
} }
if session.Statement.OmitStr != "" { if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok { if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue continue
} }
} }
// !evalphobia! set fieldValue as nil when column is nullable and zero-value // !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.Statement.nullableMap, col); ok { if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) { if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int var nilValue *int
fieldValue = reflect.ValueOf(nilValue) fieldValue = reflect.ValueOf(nilValue)
} }
} }
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time // if time is non-empty, then set to auto time
val, t := session.Engine.NowTime2(col.SQLType.Name) val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val) args = append(args, val)
var colName = col.Name var colName = col.Name
@ -430,7 +430,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.Statement.checkVersion { } else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1) args = append(args, 1)
} else { } else {
arg, err := session.value2Interface(col, fieldValue) arg, err := session.value2Interface(col, fieldValue)
@ -441,7 +441,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool,
} }
if includeQuote { if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
} else { } else {
colNames = append(colNames, col.Name) colNames = append(colNames, col.Name)
} }

View File

@ -173,7 +173,7 @@ func TestProcessors(t *testing.T) {
} }
p2 := &ProcessorsStruct{} p2 := &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2) _, err = testEngine.ID(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -308,7 +308,7 @@ func TestProcessors(t *testing.T) {
} }
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2) _, err = testEngine.ID(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -402,7 +402,7 @@ func TestProcessors(t *testing.T) {
for _, elem := range pslice { for _, elem := range pslice {
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = testEngine.Id(elem.Id).Get(p) _, err = testEngine.ID(elem.Id).Get(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -508,7 +508,7 @@ func TestProcessorsTx(t *testing.T) {
} }
session.Close() session.Close()
p2 := &ProcessorsStruct{} p2 := &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2) _, err = testEngine.ID(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -569,7 +569,7 @@ func TestProcessorsTx(t *testing.T) {
} }
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.Id(p.Id).Get(p2) _, err = testEngine.ID(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -616,7 +616,7 @@ func TestProcessorsTx(t *testing.T) {
p = p2 // reset p = p2 // reset
_, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -656,7 +656,7 @@ func TestProcessorsTx(t *testing.T) {
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.Id(insertedId).Get(p2) _, err = testEngine.ID(insertedId).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -729,7 +729,7 @@ func TestProcessorsTx(t *testing.T) {
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -767,7 +767,7 @@ func TestProcessorsTx(t *testing.T) {
} }
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.Id(insertedId).Get(p2) _, err = testEngine.ID(insertedId).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -813,7 +813,7 @@ func TestProcessorsTx(t *testing.T) {
p = &ProcessorsStruct{} // reset p = &ProcessorsStruct{} // reset
_, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -852,7 +852,7 @@ func TestProcessorsTx(t *testing.T) {
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.Id(insertedId).Get(p2) _, err = testEngine.ID(insertedId).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -882,7 +882,7 @@ func TestProcessorsTx(t *testing.T) {
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)

View File

@ -17,7 +17,6 @@ type Rows struct {
NoTypeCheck bool NoTypeCheck bool
session *Session session *Session
stmt *core.Stmt
rows *core.Rows rows *core.Rows
fields []string fields []string
beanType reflect.Type beanType reflect.Type
@ -29,56 +28,33 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
rows.session = session rows.session = session
rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type() rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
defer rows.session.resetStatement()
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error var err error
if err = rows.session.Statement.setRefValue(rValue(bean)); err != nil { if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
return nil, err return nil, err
} }
if len(session.Statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
return nil, ErrTableNotFound return nil, ErrTableNotFound
} }
if rows.session.Statement.RawSQL == "" { if rows.session.statement.RawSQL == "" {
sqlStr, args, err = rows.session.Statement.genGetSQL(bean) sqlStr, args, err = rows.session.statement.genGetSQL(bean)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
sqlStr = rows.session.Statement.RawSQL sqlStr = rows.session.statement.RawSQL
args = rows.session.Statement.RawParams args = rows.session.statement.RawParams
} }
for _, filter := range rows.session.Engine.dialect.Filters() { rows.rows, err = rows.session.queryRows(sqlStr, args...)
sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable) if err != nil {
} rows.lastError = err
rows.Close()
rows.session.saveLastSQL(sqlStr, args...) return nil, err
if rows.session.prepareStmt {
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.rows, err = rows.stmt.Query(args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
} else {
rows.rows, err = rows.session.DB().Query(sqlStr, args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
} }
rows.fields, err = rows.rows.Columns() rows.fields, err = rows.rows.Columns()
@ -119,7 +95,7 @@ func (rows *Rows) Scan(bean interface{}) error {
} }
dataStruct := rValue(bean) dataStruct := rValue(bean)
if err := rows.session.Statement.setRefValue(dataStruct); err != nil { if err := rows.session.statement.setRefValue(dataStruct); err != nil {
return err return err
} }
@ -128,13 +104,13 @@ func (rows *Rows) Scan(bean interface{}) error {
return err return err
} }
_, err = rows.session.slice2Bean(scanResults, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.Statement.RefTable) _, err = rows.session.slice2Bean(scanResults, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.statement.RefTable)
return err return err
} }
// Close session if session.IsAutoClose is true, and claimed any opened resources // Close session if session.IsAutoClose is true, and claimed any opened resources
func (rows *Rows) Close() error { func (rows *Rows) Close() error {
if rows.session.IsAutoClose { if rows.session.isAutoClose {
defer rows.session.Close() defer rows.session.Close()
} }
@ -142,17 +118,10 @@ func (rows *Rows) Close() error {
if rows.rows != nil { if rows.rows != nil {
rows.lastError = rows.rows.Close() rows.lastError = rows.rows.Close()
if rows.lastError != nil { if rows.lastError != nil {
defer rows.stmt.Close()
return rows.lastError return rows.lastError
} }
} }
if rows.stmt != nil {
rows.lastError = rows.stmt.Close()
}
} else { } else {
if rows.stmt != nil {
defer rows.stmt.Close()
}
if rows.rows != nil { if rows.rows != nil {
defer rows.rows.Close() defer rows.rows.Close()
} }

View File

@ -38,4 +38,31 @@ func TestRows(t *testing.T) {
cnt++ cnt++
} }
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
sess := testEngine.NewSession()
defer sess.Close()
rows1, err := sess.Prepare().Rows(new(UserRows))
assert.NoError(t, err)
defer rows1.Close()
cnt = 0
for rows1.Next() {
err = rows1.Scan(user)
assert.NoError(t, err)
cnt++
}
assert.EqualValues(t, 1, cnt)
rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows))
assert.NoError(t, err)
defer rows2.Close()
cnt = 0
for rows2.Next() {
err = rows2.Scan(user)
assert.NoError(t, err)
cnt++
}
assert.EqualValues(t, 1, cnt)
} }

View File

@ -21,16 +21,16 @@ import (
// kind of database operations. // kind of database operations.
type Session struct { type Session struct {
db *core.DB db *core.DB
Engine *Engine engine *Engine
Tx *core.Tx tx *core.Tx
Statement Statement statement Statement
IsAutoCommit bool isAutoCommit bool
IsCommitedOrRollbacked bool isCommitedOrRollbacked bool
IsAutoClose bool isAutoClose bool
// Automatically reset the statement after operations that execute a SQL // Automatically reset the statement after operations that execute a SQL
// query such as Count(), Find(), Get(), ... // query such as Count(), Find(), Get(), ...
AutoResetStatement bool autoResetStatement bool
// !nashtsai! storing these beans due to yet committed tx // !nashtsai! storing these beans due to yet committed tx
afterInsertBeans map[interface{}]*[]func(interface{}) afterInsertBeans map[interface{}]*[]func(interface{})
@ -60,12 +60,12 @@ func (session *Session) Clone() *Session {
// Init reset the session as the init status. // Init reset the session as the init status.
func (session *Session) Init() { func (session *Session) Init() {
session.Statement.Init() session.statement.Init()
session.Statement.Engine = session.Engine session.statement.Engine = session.engine
session.IsAutoCommit = true session.isAutoCommit = true
session.IsCommitedOrRollbacked = false session.isCommitedOrRollbacked = false
session.IsAutoClose = false session.isAutoClose = false
session.AutoResetStatement = true session.autoResetStatement = true
session.prepareStmt = false session.prepareStmt = false
// !nashtsai! is lazy init better? // !nashtsai! is lazy init better?
@ -88,19 +88,23 @@ func (session *Session) Close() {
if session.db != nil { if session.db != nil {
// When Close be called, if session is a transaction and do not call // When Close be called, if session is a transaction and do not call
// Commit or Rollback, then call Rollback. // Commit or Rollback, then call Rollback.
if session.Tx != nil && !session.IsCommitedOrRollbacked { if session.tx != nil && !session.isCommitedOrRollbacked {
session.Rollback() session.Rollback()
} }
session.Tx = nil session.tx = nil
session.stmtCache = nil session.stmtCache = nil
session.Init()
session.db = nil session.db = nil
} }
} }
// IsClosed returns if session is closed
func (session *Session) IsClosed() bool {
return session.db == nil
}
func (session *Session) resetStatement() { func (session *Session) resetStatement() {
if session.AutoResetStatement { if session.autoResetStatement {
session.Statement.Init() session.statement.Init()
} }
} }
@ -128,75 +132,75 @@ func (session *Session) After(closures func(interface{})) *Session {
// Table can input a string or pointer to struct for special a table to operate. // Table can input a string or pointer to struct for special a table to operate.
func (session *Session) Table(tableNameOrBean interface{}) *Session { func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.Statement.Table(tableNameOrBean) session.statement.Table(tableNameOrBean)
return session return session
} }
// Alias set the table alias // Alias set the table alias
func (session *Session) Alias(alias string) *Session { func (session *Session) Alias(alias string) *Session {
session.Statement.Alias(alias) session.statement.Alias(alias)
return session return session
} }
// NoCascade indicate that no cascade load child object // NoCascade indicate that no cascade load child object
func (session *Session) NoCascade() *Session { func (session *Session) NoCascade() *Session {
session.Statement.UseCascade = false session.statement.UseCascade = false
return session return session
} }
// ForUpdate Set Read/Write locking for UPDATE // ForUpdate Set Read/Write locking for UPDATE
func (session *Session) ForUpdate() *Session { func (session *Session) ForUpdate() *Session {
session.Statement.IsForUpdate = true session.statement.IsForUpdate = true
return session return session
} }
// NoAutoCondition disable generate SQL condition from beans // NoAutoCondition disable generate SQL condition from beans
func (session *Session) NoAutoCondition(no ...bool) *Session { func (session *Session) NoAutoCondition(no ...bool) *Session {
session.Statement.NoAutoCondition(no...) session.statement.NoAutoCondition(no...)
return session return session
} }
// Limit provide limit and offset query condition // Limit provide limit and offset query condition
func (session *Session) Limit(limit int, start ...int) *Session { func (session *Session) Limit(limit int, start ...int) *Session {
session.Statement.Limit(limit, start...) session.statement.Limit(limit, start...)
return session return session
} }
// OrderBy provide order by query condition, the input parameter is the content // OrderBy provide order by query condition, the input parameter is the content
// after order by on a sql statement. // after order by on a sql statement.
func (session *Session) OrderBy(order string) *Session { func (session *Session) OrderBy(order string) *Session {
session.Statement.OrderBy(order) session.statement.OrderBy(order)
return session return session
} }
// Desc provide desc order by query condition, the input parameters are columns. // Desc provide desc order by query condition, the input parameters are columns.
func (session *Session) Desc(colNames ...string) *Session { func (session *Session) Desc(colNames ...string) *Session {
session.Statement.Desc(colNames...) session.statement.Desc(colNames...)
return session return session
} }
// Asc provide asc order by query condition, the input parameters are columns. // Asc provide asc order by query condition, the input parameters are columns.
func (session *Session) Asc(colNames ...string) *Session { func (session *Session) Asc(colNames ...string) *Session {
session.Statement.Asc(colNames...) session.statement.Asc(colNames...)
return session return session
} }
// StoreEngine is only avialble mysql dialect currently // StoreEngine is only avialble mysql dialect currently
func (session *Session) StoreEngine(storeEngine string) *Session { func (session *Session) StoreEngine(storeEngine string) *Session {
session.Statement.StoreEngine = storeEngine session.statement.StoreEngine = storeEngine
return session return session
} }
// Charset is only avialble mysql dialect currently // Charset is only avialble mysql dialect currently
func (session *Session) Charset(charset string) *Session { func (session *Session) Charset(charset string) *Session {
session.Statement.Charset = charset session.statement.Charset = charset
return session return session
} }
// Cascade indicates if loading sub Struct // Cascade indicates if loading sub Struct
func (session *Session) Cascade(trueOrFalse ...bool) *Session { func (session *Session) Cascade(trueOrFalse ...bool) *Session {
if len(trueOrFalse) >= 1 { if len(trueOrFalse) >= 1 {
session.Statement.UseCascade = trueOrFalse[0] session.statement.UseCascade = trueOrFalse[0]
} }
return session return session
} }
@ -204,32 +208,32 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session {
// NoCache ask this session do not retrieve data from cache system and // NoCache ask this session do not retrieve data from cache system and
// get data from database directly. // get data from database directly.
func (session *Session) NoCache() *Session { func (session *Session) NoCache() *Session {
session.Statement.UseCache = false session.statement.UseCache = false
return session return session
} }
// Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session.Statement.Join(joinOperator, tablename, condition, args...) session.statement.Join(joinOperator, tablename, condition, args...)
return session return session
} }
// GroupBy Generate Group By statement // GroupBy Generate Group By statement
func (session *Session) GroupBy(keys string) *Session { func (session *Session) GroupBy(keys string) *Session {
session.Statement.GroupBy(keys) session.statement.GroupBy(keys)
return session return session
} }
// Having Generate Having statement // Having Generate Having statement
func (session *Session) Having(conditions string) *Session { func (session *Session) Having(conditions string) *Session {
session.Statement.Having(conditions) session.statement.Having(conditions)
return session return session
} }
// DB db return the wrapper of sql.DB // DB db return the wrapper of sql.DB
func (session *Session) DB() *core.DB { func (session *Session) DB() *core.DB {
if session.db == nil { if session.db == nil {
session.db = session.Engine.db session.db = session.engine.db
session.stmtCache = make(map[uint32]*core.Stmt, 0) session.stmtCache = make(map[uint32]*core.Stmt, 0)
} }
return session.db return session.db
@ -242,13 +246,13 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) {
} }
func (session *Session) canCache() bool { func (session *Session) canCache() bool {
if session.Statement.RefTable == nil || if session.statement.RefTable == nil ||
session.Statement.JoinStr != "" || session.statement.JoinStr != "" ||
session.Statement.RawSQL != "" || session.statement.RawSQL != "" ||
!session.Statement.UseCache || !session.statement.UseCache ||
session.Statement.IsForUpdate || session.statement.IsForUpdate ||
session.Tx != nil || session.tx != nil ||
len(session.Statement.selectStr) > 0 { len(session.statement.selectStr) > 0 {
return false return false
} }
return true return true
@ -272,18 +276,18 @@ func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) {
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value {
var col *core.Column var col *core.Column
if col = table.GetColumnIdx(key, idx); col == nil { if col = table.GetColumnIdx(key, idx); col == nil {
//session.Engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) //session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq())
return nil return nil
} }
fieldValue, err := col.ValueOfV(dataStruct) fieldValue, err := col.ValueOfV(dataStruct)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return nil return nil
} }
if !fieldValue.IsValid() || !fieldValue.CanSet() { if !fieldValue.IsValid() || !fieldValue.CanSet() {
session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
return nil return nil
} }
return fieldValue return fieldValue
@ -528,7 +532,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
} }
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(core.TimeType) {
dbTZ := session.Engine.DatabaseTZ dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil { if col.TimeZone != nil {
dbTZ = col.TimeZone dbTZ = col.TimeZone
} }
@ -541,25 +545,25 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
z, _ := t.Zone() z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone // set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ) t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
} }
t = t.In(session.Engine.TZLocation) t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type || } else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type { rawValueType == core.Int32Type {
hasAssigned = true hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.Engine.TZLocation) t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else { } else {
if d, ok := vv.Interface().([]uint8); ok { if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true hasAssigned = true
t, err := session.byte2Time(col, d) t, err := session.byte2Time(col, d)
if err != nil { if err != nil {
session.Engine.logger.Error("byte2Time error:", err.Error()) session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false hasAssigned = false
} else { } else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
@ -568,7 +572,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
hasAssigned = true hasAssigned = true
t, err := session.str2Time(col, d) t, err := session.str2Time(col, d)
if err != nil { if err != nil {
session.Engine.logger.Error("byte2Time error:", err.Error()) session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false hasAssigned = false
} else { } else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
@ -581,7 +585,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString // !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil { if err := nulVal.Scan(vv.Interface()); err != nil {
session.Engine.logger.Error("sql.Sanner error:", err.Error()) session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false hasAssigned = false
} }
} else if col.SQLType.IsJson() { } else if col.SQLType.IsJson() {
@ -606,8 +610,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} }
} }
} else if session.Statement.UseCascade { } else if session.statement.UseCascade {
table, err := session.Engine.autoMapType(*fieldValue) table, err := session.engine.autoMapType(*fieldValue)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -627,9 +631,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily // property to be fetched lazily
structInter := reflect.New(fieldValue.Type()) structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession() has, err := session.ID(pk).NoCascade().get(structInter.Interface())
defer newsession.Close()
has, err := newsession.ID(pk).NoCascade().Get(structInter.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -773,19 +775,11 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
return pk, nil return pk, nil
} }
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.Engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
session.saveLastSQL(*sqlStr, paramStr...)
}
// saveLastSQL stores executed query information // saveLastSQL stores executed query information
func (session *Session) saveLastSQL(sql string, args ...interface{}) { func (session *Session) saveLastSQL(sql string, args ...interface{}) {
session.lastSQL = sql session.lastSQL = sql
session.lastSQLArgs = args session.lastSQLArgs = args
session.Engine.logSQL(sql, args...) session.engine.logSQL(sql, args...)
} }
// LastSQL returns last query information // LastSQL returns last query information
@ -795,8 +789,8 @@ func (session *Session) LastSQL() (string, []interface{}) {
// tbName get some table's table name // tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string { func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.Statement.AltTableName) > 0 { if len(session.statement.AltTableName) > 0 {
return session.Statement.AltTableName return session.statement.AltTableName
} }
return table.Name return table.Name
@ -804,6 +798,6 @@ func (session *Session) tbNameNoSchema(table *core.Table) string {
// Unscoped always disable struct tag "deleted" // Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session { func (session *Session) Unscoped() *Session {
session.Statement.Unscoped() session.statement.Unscoped()
return session return session
} }

View File

@ -6,43 +6,43 @@ package xorm
// Incr provides a query string like "count = count + 1" // Incr provides a query string like "count = count + 1"
func (session *Session) Incr(column string, arg ...interface{}) *Session { func (session *Session) Incr(column string, arg ...interface{}) *Session {
session.Statement.Incr(column, arg...) session.statement.Incr(column, arg...)
return session return session
} }
// Decr provides a query string like "count = count - 1" // Decr provides a query string like "count = count - 1"
func (session *Session) Decr(column string, arg ...interface{}) *Session { func (session *Session) Decr(column string, arg ...interface{}) *Session {
session.Statement.Decr(column, arg...) session.statement.Decr(column, arg...)
return session return session
} }
// SetExpr provides a query string like "column = {expression}" // SetExpr provides a query string like "column = {expression}"
func (session *Session) SetExpr(column string, expression string) *Session { func (session *Session) SetExpr(column string, expression string) *Session {
session.Statement.SetExpr(column, expression) session.statement.SetExpr(column, expression)
return session return session
} }
// Select provides some columns to special // Select provides some columns to special
func (session *Session) Select(str string) *Session { func (session *Session) Select(str string) *Session {
session.Statement.Select(str) session.statement.Select(str)
return session return session
} }
// Cols provides some columns to special // Cols provides some columns to special
func (session *Session) Cols(columns ...string) *Session { func (session *Session) Cols(columns ...string) *Session {
session.Statement.Cols(columns...) session.statement.Cols(columns...)
return session return session
} }
// AllCols ask all columns // AllCols ask all columns
func (session *Session) AllCols() *Session { func (session *Session) AllCols() *Session {
session.Statement.AllCols() session.statement.AllCols()
return session return session
} }
// MustCols specify some columns must use even if they are empty // MustCols specify some columns must use even if they are empty
func (session *Session) MustCols(columns ...string) *Session { func (session *Session) MustCols(columns ...string) *Session {
session.Statement.MustCols(columns...) session.statement.MustCols(columns...)
return session return session
} }
@ -52,7 +52,7 @@ func (session *Session) MustCols(columns ...string) *Session {
// If no parameters, it will use all the bool field of struct, or // If no parameters, it will use all the bool field of struct, or
// it will use parameters's columns // it will use parameters's columns
func (session *Session) UseBool(columns ...string) *Session { func (session *Session) UseBool(columns ...string) *Session {
session.Statement.UseBool(columns...) session.statement.UseBool(columns...)
return session return session
} }
@ -60,25 +60,25 @@ func (session *Session) UseBool(columns ...string) *Session {
// distinct will not be cached because cache system need id, // distinct will not be cached because cache system need id,
// but distinct will not provide id // but distinct will not provide id
func (session *Session) Distinct(columns ...string) *Session { func (session *Session) Distinct(columns ...string) *Session {
session.Statement.Distinct(columns...) session.statement.Distinct(columns...)
return session return session
} }
// Omit Only not use the parameters as select or update columns // Omit Only not use the parameters as select or update columns
func (session *Session) Omit(columns ...string) *Session { func (session *Session) Omit(columns ...string) *Session {
session.Statement.Omit(columns...) session.statement.Omit(columns...)
return session return session
} }
// Nullable Set null when column is zero-value and nullable for update // Nullable Set null when column is zero-value and nullable for update
func (session *Session) Nullable(columns ...string) *Session { func (session *Session) Nullable(columns ...string) *Session {
session.Statement.Nullable(columns...) session.statement.Nullable(columns...)
return session return session
} }
// NoAutoTime means do not automatically give created field and updated field // NoAutoTime means do not automatically give created field and updated field
// the current time on the current session temporarily // the current time on the current session temporarily
func (session *Session) NoAutoTime() *Session { func (session *Session) NoAutoTime() *Session {
session.Statement.UseAutoTime = false session.statement.UseAutoTime = false
return session return session
} }

View File

@ -31,7 +31,39 @@ func TestSetExpr(t *testing.T) {
if testEngine.dialect.DBType() == core.MSSQL { if testEngine.dialect.DBType() == core.MSSQL {
not = "~" not = "~"
} }
cnt, err = testEngine.SetExpr("show", not+" `show`").Id(1).Update(new(User)) cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(User))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
} }
func TestCols(t *testing.T) {
assert.NoError(t, prepareEngine())
type ColsTable struct {
Id int64
Col1 string
Col2 string
}
assertSync(t, new(ColsTable))
_, err := testEngine.Insert(&ColsTable{
Col1: "1",
Col2: "2",
})
assert.NoError(t, err)
sess := testEngine.ID(1)
_, err = sess.Cols("col1").Cols("col2").Update(&ColsTable{
Col1: "",
Col2: "",
})
assert.NoError(t, err)
var tb ColsTable
has, err := testEngine.ID(1).Get(&tb)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "", tb.Col1)
assert.EqualValues(t, "", tb.Col2)
}

View File

@ -17,25 +17,25 @@ func (session *Session) Sql(query string, args ...interface{}) *Session {
// SQL provides raw sql input parameter. When you have a complex SQL statement // SQL provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.
func (session *Session) SQL(query interface{}, args ...interface{}) *Session { func (session *Session) SQL(query interface{}, args ...interface{}) *Session {
session.Statement.SQL(query, args...) session.statement.SQL(query, args...)
return session return session
} }
// Where provides custom query condition. // Where provides custom query condition.
func (session *Session) Where(query interface{}, args ...interface{}) *Session { func (session *Session) Where(query interface{}, args ...interface{}) *Session {
session.Statement.Where(query, args...) session.statement.Where(query, args...)
return session return session
} }
// And provides custom query condition. // And provides custom query condition.
func (session *Session) And(query interface{}, args ...interface{}) *Session { func (session *Session) And(query interface{}, args ...interface{}) *Session {
session.Statement.And(query, args...) session.statement.And(query, args...)
return session return session
} }
// Or provides custom query condition. // Or provides custom query condition.
func (session *Session) Or(query interface{}, args ...interface{}) *Session { func (session *Session) Or(query interface{}, args ...interface{}) *Session {
session.Statement.Or(query, args...) session.statement.Or(query, args...)
return session return session
} }
@ -48,23 +48,23 @@ func (session *Session) Id(id interface{}) *Session {
// ID provides converting id as a query condition // ID provides converting id as a query condition
func (session *Session) ID(id interface{}) *Session { func (session *Session) ID(id interface{}) *Session {
session.Statement.ID(id) session.statement.ID(id)
return session return session
} }
// In provides a query string like "id in (1, 2, 3)" // In provides a query string like "id in (1, 2, 3)"
func (session *Session) In(column string, args ...interface{}) *Session { func (session *Session) In(column string, args ...interface{}) *Session {
session.Statement.In(column, args...) session.statement.In(column, args...)
return session return session
} }
// NotIn provides a query string like "id in (1, 2, 3)" // NotIn provides a query string like "id in (1, 2, 3)"
func (session *Session) NotIn(column string, args ...interface{}) *Session { func (session *Session) NotIn(column string, args ...interface{}) *Session {
session.Statement.NotIn(column, args...) session.statement.NotIn(column, args...)
return session return session
} }
// Conds returns session query conditions // Conds returns session query conditions except auto bean conditions
func (session *Session) Conds() builder.Cond { func (session *Session) Conds() builder.Cond {
return session.Statement.cond return session.statement.cond
} }

View File

@ -83,6 +83,11 @@ func TestBuilder(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist") assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.NotIn("col_name", "col1", "col2").Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 0, len(conds), "records should not exist")
// complex condtions // complex condtions
var where = builder.NewCond() var where = builder.NewCond()
if true { if true {
@ -222,7 +227,7 @@ func TestIn(t *testing.T) {
} }
user := new(Userinfo) user := new(Userinfo)
has, err := testEngine.Id(ids[0]).Get(user) has, err := testEngine.ID(ids[0]).Get(user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -260,3 +265,35 @@ func TestIn(t *testing.T) {
panic(err) panic(err)
} }
} }
func TestFindAndCount(t *testing.T) {
assert.NoError(t, prepareEngine())
type FindAndCount struct {
Id int64
Name string
}
assert.NoError(t, testEngine.Sync2(new(FindAndCount)))
_, err := testEngine.Insert([]FindAndCount{
{
Name: "test1",
},
{
Name: "test2",
},
})
assert.NoError(t, err)
var results []FindAndCount
sess := testEngine.Where("name = ?", "test1")
conds := sess.Conds()
err = sess.Find(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
total, err := testEngine.Where(conds).Count(new(FindAndCount))
assert.NoError(t, err)
assert.EqualValues(t, 1, total)
}

View File

@ -23,7 +23,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
var x time.Time var x time.Time
var err error var err error
var parseLoc = session.Engine.DatabaseTZ var parseLoc = session.engine.DatabaseTZ
if col.TimeZone != nil { if col.TimeZone != nil {
parseLoc = col.TimeZone parseLoc = col.TimeZone
} }
@ -34,27 +34,27 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
sd, err := strconv.ParseInt(sdata, 10, 64) sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil { if err == nil {
x = time.Unix(sd, 0) x = time.Unix(sd, 0)
session.Engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else { } else {
session.Engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
} else if len(sdata) > 19 && strings.Contains(sdata, "-") { } else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.Engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
if err != nil { if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
session.Engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
if err != nil { if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
session.Engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
} else if len(sdata) == 19 && strings.Contains(sdata, "-") { } else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
session.Engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
session.Engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == core.Time { } else if col.SQLType.Name == core.Time {
if strings.Contains(sdata, " ") { if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ") ssd := strings.Split(sdata, " ")
@ -62,13 +62,13 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
} }
sdata = strings.TrimSpace(sdata) sdata = strings.TrimSpace(sdata)
if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:] sdata = sdata[len(sdata)-8:]
} }
st := fmt.Sprintf("2006-01-02 %v", sdata) st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
session.Engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else { } else {
outErr = fmt.Errorf("unsupported time format %v", sdata) outErr = fmt.Errorf("unsupported time format %v", sdata)
return return
@ -77,7 +77,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err) outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err)
return return
} }
outTime = x.In(session.Engine.TZLocation) outTime = x.In(session.engine.TZLocation)
return return
} }
@ -105,7 +105,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, x.Interface()) err := json.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
@ -119,7 +119,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, x.Interface()) err := json.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
@ -132,7 +132,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, x.Interface()) err := json.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
@ -156,7 +156,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == core.Bit &&
session.Engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if len(data) == 1 { if len(data) == 1 {
x = int64(data[0]) x = int64(data[0])
} else { } else {
@ -204,8 +204,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
v = x v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if session.Statement.UseCascade { } else if session.statement.UseCascade {
table, err := session.Engine.autoMapType(*fieldValue) table, err := session.engine.autoMapType(*fieldValue)
if err != nil { if err != nil {
return err return err
} }
@ -227,9 +227,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily // property to be fetched lazily
structInter := reflect.New(fieldValue.Type()) structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession() has, err := session.ID(pk).NoCascade().get(structInter.Interface())
defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
if err != nil { if err != nil {
return err return err
} }
@ -264,7 +262,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, &x) err := json.Unmarshal(data, &x)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@ -275,7 +273,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, &x) err := json.Unmarshal(data, &x)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@ -347,7 +345,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int64(data[0]) x = int64(data[0])
} else { } else {
@ -372,7 +370,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int(data[0]) x = int(data[0])
} else { } else {
@ -400,7 +398,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == core.Bit &&
session.Engine.dialect.DBType() == core.MYSQL { session.engine.dialect.DBType() == core.MYSQL {
if len(data) == 1 { if len(data) == 1 {
x = int32(data[0]) x = int32(data[0])
} else { } else {
@ -428,7 +426,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int8(data[0]) x = int8(data[0])
} else { } else {
@ -456,7 +454,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int16(data[0]) x = int16(data[0])
} else { } else {
@ -488,9 +486,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
v = x v = x
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
default: default:
if session.Statement.UseCascade { if session.statement.UseCascade {
structInter := reflect.New(fieldType.Elem()) structInter := reflect.New(fieldType.Elem())
table, err := session.Engine.autoMapType(structInter.Elem()) table, err := session.engine.autoMapType(structInter.Elem())
if err != nil { if err != nil {
return err return err
} }
@ -510,9 +508,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily // property to be fetched lazily
newsession := session.Engine.NewSession() has, err := session.ID(pk).NoCascade().get(structInter.Interface())
defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
if err != nil { if err != nil {
return err return err
} }
@ -569,7 +565,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if fieldValue.IsNil() { if fieldValue.IsNil() {
return nil, nil return nil, nil
} else if !fieldValue.IsValid() { } else if !fieldValue.IsValid() {
session.Engine.logger.Warn("the field[", col.FieldName, "] is invalid") session.engine.logger.Warn("the field[", col.FieldName, "] is invalid")
return nil, nil return nil, nil
} else { } else {
// !nashtsai! deference pointer type to instance type // !nashtsai! deference pointer type to instance type
@ -587,7 +583,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time) t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
tf := session.Engine.formatColTime(col, t) tf := session.engine.formatColTime(col, t)
return tf, nil return tf, nil
} }
@ -597,7 +593,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
return v.Value() return v.Value()
} }
fieldTable, err := session.Engine.autoMapType(fieldValue) fieldTable, err := session.engine.autoMapType(fieldValue)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -611,14 +607,14 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := json.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
} }
return string(bytes), nil return string(bytes), nil
} else if col.SQLType.IsBlob() { } else if col.SQLType.IsBlob() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := json.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
} }
return bytes, nil return bytes, nil
@ -627,7 +623,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := json.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
} }
return string(bytes), nil return string(bytes), nil
@ -639,7 +635,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := json.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
} }
return string(bytes), nil return string(bytes), nil
@ -652,7 +648,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} else { } else {
bytes, err = json.Marshal(fieldValue.Interface()) bytes, err = json.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
} }
} }

View File

@ -13,25 +13,25 @@ import (
) )
func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil || if session.statement.RefTable == nil ||
session.Tx != nil { session.tx != nil {
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.Engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
} }
newsql := session.Statement.convertIDSQL(sqlStr) newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
cacher := session.Engine.getCacher2(session.Statement.RefTable) cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.Statement.TableName() tableName := session.statement.TableName()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
resultsSlice, err := session.query(newsql, args...) resultsSlice, err := session.queryBytes(newsql, args...)
if err != nil { if err != nil {
return err return err
} }
@ -40,7 +40,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
for _, data := range resultsSlice { for _, data := range resultsSlice {
var id int64 var id int64
var pk core.PK = make([]interface{}, 0) var pk core.PK = make([]interface{}, 0)
for _, col := range session.Statement.RefTable.PKColumns() { for _, col := range session.statement.RefTable.PKColumns() {
if v, ok := data[col.Name]; !ok { if v, ok := data[col.Name]; !ok {
return errors.New("no id") return errors.New("no id")
} else if col.SQLType.IsText() { } else if col.SQLType.IsText() {
@ -59,34 +59,33 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
} }
} }
} /*else { } /*else {
session.Engine.LogDebug("delete cache sql %v", newsql) session.engine.LogDebug("delete cache sql %v", newsql)
cacher.DelIds(tableName, genSqlKey(newsql, args)) cacher.DelIds(tableName, genSqlKey(newsql, args))
}*/ }*/
for _, id := range ids { for _, id := range ids {
session.Engine.logger.Debug("[cacheDelete] delete cache obj", tableName, id) session.engine.logger.Debug("[cacheDelete] delete cache obj", tableName, id)
sid, err := id.ToString() sid, err := id.ToString()
if err != nil { if err != nil {
return err return err
} }
cacher.DelBean(tableName, sid) cacher.DelBean(tableName, sid)
} }
session.Engine.logger.Debug("[cacheDelete] clear cache sql", tableName) session.engine.logger.Debug("[cacheDelete] clear cache sql", tableName)
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
return nil return nil
} }
// Delete records, bean's non-empty fields are conditions // Delete records, bean's non-empty fields are conditions
func (session *Session) Delete(bean interface{}) (int64, error) { func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
if err := session.Statement.setRefValue(rValue(bean)); err != nil { if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err return 0, err
} }
var table = session.Statement.RefTable var table = session.statement.RefTable
// handle before delete processors // handle before delete processors
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
@ -98,15 +97,15 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
processor.BeforeDelete() processor.BeforeDelete()
} }
condSQL, condArgs, err := session.Statement.genConds(bean) condSQL, condArgs, err := session.statement.genConds(bean)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if len(condSQL) == 0 && session.Statement.LimitN == 0 { if len(condSQL) == 0 && session.statement.LimitN == 0 {
return 0, ErrNeedDeletedCond return 0, ErrNeedDeletedCond
} }
var tableName = session.Engine.Quote(session.Statement.TableName()) var tableName = session.engine.Quote(session.statement.TableName())
var deleteSQL string var deleteSQL string
if len(condSQL) > 0 { if len(condSQL) > 0 {
deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL)
@ -115,15 +114,15 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} }
var orderSQL string var orderSQL string
if len(session.Statement.OrderStr) > 0 { if len(session.statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.Statement.OrderStr) orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
} }
if session.Statement.LimitN > 0 { if session.statement.LimitN > 0 {
orderSQL += fmt.Sprintf(" LIMIT %d", session.Statement.LimitN) orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN)
} }
if len(orderSQL) > 0 { if len(orderSQL) > 0 {
switch session.Engine.dialect.DBType() { switch session.engine.dialect.DBType() {
case core.POSTGRES: case core.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {
@ -148,7 +147,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
var realSQL string var realSQL string
argsForCache := make([]interface{}, 0, len(condArgs)*2) argsForCache := make([]interface{}, 0, len(condArgs)*2)
if session.Statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled
realSQL = deleteSQL realSQL = deleteSQL
copy(argsForCache, condArgs) copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...) argsForCache = append(condArgs, argsForCache...)
@ -159,12 +158,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
deletedColumn := table.DeletedColumn() deletedColumn := table.DeletedColumn()
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v",
session.Engine.Quote(session.Statement.TableName()), session.engine.Quote(session.statement.TableName()),
session.Engine.Quote(deletedColumn.Name), session.engine.Quote(deletedColumn.Name),
condSQL) condSQL)
if len(orderSQL) > 0 { if len(orderSQL) > 0 {
switch session.Engine.dialect.DBType() { switch session.engine.dialect.DBType() {
case core.POSTGRES: case core.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {
@ -187,12 +186,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} }
} }
// !oinume! Insert NowTime to the head of session.Statement.Params // !oinume! Insert NowTime to the head of session.statement.Params
condArgs = append(condArgs, "") condArgs = append(condArgs, "")
paramsLen := len(condArgs) paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
val, t := session.Engine.NowTime2(deletedColumn.SQLType.Name) val, t := session.engine.NowTime2(deletedColumn.SQLType.Name)
condArgs[0] = val condArgs[0] = val
var colName = deletedColumn.Name var colName = deletedColumn.Name
@ -202,7 +201,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}) })
} }
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache { if cacher := session.engine.getCacher2(session.statement.RefTable); cacher != nil && session.statement.UseCache {
session.cacheDelete(deleteSQL, argsForCache...) session.cacheDelete(deleteSQL, argsForCache...)
} }
@ -212,7 +211,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} }
// handle after delete processors // handle after delete processors
if session.IsAutoCommit { if session.isAutoCommit {
for _, closure := range session.afterClosures { for _, closure := range session.afterClosures {
closure(bean) closure(bean)
} }

View File

@ -26,13 +26,27 @@ func TestDelete(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.Delete(&UserinfoDelete{Uid: 1}) cnt, err = testEngine.Delete(&UserinfoDelete{Uid: user.Uid})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
user.Uid = 0 user.Uid = 0
user.IsMan = true user.IsMan = true
has, err := testEngine.Id(1).Get(&user) has, err := testEngine.ID(1).Get(&user)
assert.NoError(t, err)
assert.False(t, has)
cnt, err = testEngine.Insert(&user)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.Where("id=?", user.Uid).Delete(&UserinfoDelete{})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
user.Uid = 0
user.IsMan = true
has, err = testEngine.ID(2).Get(&user)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, has) assert.False(t, has)
} }
@ -68,16 +82,16 @@ func TestDeleted(t *testing.T) {
// Test normal Get() // Test normal Get()
record1 := &Deleted{} record1 := &Deleted{}
has, err := testEngine.Id(1).Get(record1) has, err := testEngine.ID(1).Get(record1)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
// Test Delete() with deleted // Test Delete() with deleted
affected, err := testEngine.Id(1).Delete(&Deleted{}) affected, err := testEngine.ID(1).Delete(&Deleted{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, affected) assert.EqualValues(t, 1, affected)
has, err = testEngine.Id(1).Get(&Deleted{}) has, err = testEngine.ID(1).Get(&Deleted{})
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, has) assert.False(t, has)
@ -87,17 +101,17 @@ func TestDeleted(t *testing.T) {
assert.EqualValues(t, 2, len(records2)) assert.EqualValues(t, 2, len(records2))
// Test no rows affected after Delete() again. // Test no rows affected after Delete() again.
affected, err = testEngine.Id(1).Delete(&Deleted{}) affected, err = testEngine.ID(1).Delete(&Deleted{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, affected) assert.EqualValues(t, 0, affected)
// Deleted.DeletedAt must not be updated. // Deleted.DeletedAt must not be updated.
affected, err = testEngine.Id(2).Update(&Deleted{Name: "2", DeletedAt: time.Now()}) affected, err = testEngine.ID(2).Update(&Deleted{Name: "2", DeletedAt: time.Now()})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, affected) assert.EqualValues(t, 1, affected)
record2 := &Deleted{} record2 := &Deleted{}
has, err = testEngine.Id(2).Get(record2) has, err = testEngine.ID(2).Get(record2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, record2.DeletedAt.IsZero()) assert.True(t, record2.DeletedAt.IsZero())
@ -108,7 +122,7 @@ func TestDeleted(t *testing.T) {
assert.EqualValues(t, 3, len(unscopedRecords1)) assert.EqualValues(t, 3, len(unscopedRecords1))
// Delete() must really delete a record with Unscoped() // Delete() must really delete a record with Unscoped()
affected, err = testEngine.Unscoped().Id(1).Delete(&Deleted{}) affected, err = testEngine.Unscoped().ID(1).Delete(&Deleted{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, affected) assert.EqualValues(t, 1, affected)

View File

@ -23,11 +23,13 @@ const (
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct // are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct // map[int64]*Struct
func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error { func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
return session.find(rowsSlicePtr, condiBean...)
}
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return errors.New("needs a pointer to a slice or a map") return errors.New("needs a pointer to a slice or a map")
@ -36,11 +38,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
var tp = tpStruct var tp = tpStruct
if session.Statement.RefTable == nil { if session.statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
if err := session.Statement.setRefValue(pv.Elem()); err != nil { if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err return err
} }
} else { } else {
@ -48,7 +50,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} else if sliceElementType.Kind() == reflect.Struct { } else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType) pv := reflect.New(sliceElementType)
if err := session.Statement.setRefValue(pv.Elem()); err != nil { if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err return err
} }
} else { } else {
@ -56,31 +58,31 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} }
var table = session.Statement.RefTable var table = session.statement.RefTable
var addedTableName = (len(session.Statement.JoinStr) > 0) var addedTableName = (len(session.statement.JoinStr) > 0)
var autoCond builder.Cond var autoCond builder.Cond
if tp == tpStruct { if tp == tpStruct {
if !session.Statement.noAutoCondition && len(condiBean) > 0 { if !session.statement.noAutoCondition && len(condiBean) > 0 {
var err error var err error
autoCond, err = session.Statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given. // !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
// See https://github.com/go-xorm/xorm/issues/179 // See https://github.com/go-xorm/xorm/issues/179
if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled
var colName = session.Engine.Quote(col.Name) var colName = session.engine.Quote(col.Name)
if addedTableName { if addedTableName {
var nm = session.Statement.TableName() var nm = session.statement.TableName()
if len(session.Statement.TableAlias) > 0 { if len(session.statement.TableAlias) > 0 {
nm = session.Statement.TableAlias nm = session.statement.TableAlias
} }
colName = session.Engine.Quote(nm) + "." + colName colName = session.engine.Quote(nm) + "." + colName
} }
autoCond = session.Engine.CondDeleted(colName) autoCond = session.engine.CondDeleted(colName)
} }
} }
} }
@ -88,27 +90,27 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error var err error
if session.Statement.RawSQL == "" { if session.statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
return ErrTableNotFound return ErrTableNotFound
} }
var columnStr = session.Statement.ColumnStr var columnStr = session.statement.ColumnStr
if len(session.Statement.selectStr) > 0 { if len(session.statement.selectStr) > 0 {
columnStr = session.Statement.selectStr columnStr = session.statement.selectStr
} else { } else {
if session.Statement.JoinStr == "" { if session.statement.JoinStr == "" {
if columnStr == "" { if columnStr == "" {
if session.Statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else { } else {
columnStr = session.Statement.genColumnStr() columnStr = session.statement.genColumnStr()
} }
} }
} else { } else {
if columnStr == "" { if columnStr == "" {
if session.Statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
} else { } else {
columnStr = "*" columnStr = "*"
} }
@ -119,13 +121,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} }
condSQL, condArgs, err := builder.ToSQL(session.Statement.cond.And(autoCond)) session.statement.cond = session.statement.cond.And(autoCond)
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil { if err != nil {
return err return err
} }
args = append(session.Statement.joinArgs, condArgs...) args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.Statement.genSelectSQL(columnStr, condSQL) sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
if err != nil { if err != nil {
return err return err
} }
@ -135,20 +138,20 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
args = append(args, args...) args = append(args, args...)
} }
} else { } else {
sqlStr = session.Statement.RawSQL sqlStr = session.statement.RawSQL
args = session.Statement.RawParams args = session.statement.RawParams
} }
if session.canCache() { if session.canCache() {
if cacher := session.Engine.getCacher2(table); cacher != nil && if cacher := session.engine.getCacher2(table); cacher != nil &&
!session.Statement.IsDistinct && !session.statement.IsDistinct &&
!session.Statement.unscoped { !session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return err return err
} }
err = nil // !nashtsai! reset err to nil for ErrCacheFailed err = nil // !nashtsai! reset err to nil for ErrCacheFailed
session.Engine.logger.Warn("Cache Find Failed") session.engine.logger.Warn("Cache Find Failed")
} }
} }
@ -156,21 +159,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
var rawRows *core.Rows rows, err := session.queryRows(sqlStr, args...)
var err error
session.queryPreprocess(&sqlStr, args...)
if session.IsAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.Tx.Query(sqlStr, args...)
}
if err != nil { if err != nil {
return err return err
} }
defer rawRows.Close() defer rows.Close()
fields, err := rawRows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
return err return err
} }
@ -240,24 +235,24 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
dataStruct := rValue(newValue.Interface()) dataStruct := rValue(newValue.Interface())
tb, err := session.Engine.autoMapType(dataStruct) tb, err := session.engine.autoMapType(dataStruct)
if err != nil { if err != nil {
return err return err
} }
return session.rows2Beans(rawRows, fields, len(fields), tb, newElemFunc, containerValueSetFunc) return session.rows2Beans(rows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
} }
for rawRows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
switch elemType.Kind() { switch elemType.Kind() {
case reflect.Slice: case reflect.Slice:
err = rawRows.ScanSlice(bean) err = rows.ScanSlice(bean)
case reflect.Map: case reflect.Map:
err = rawRows.ScanMap(bean) err = rows.ScanMap(bean)
default: default:
err = rawRows.Scan(bean) err = rows.Scan(bean)
} }
if err != nil { if err != nil {
@ -288,22 +283,22 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.Engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
} }
newsql := session.Statement.convertIDSQL(sqlStr) newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
tableName := session.Statement.TableName() tableName := session.statement.TableName()
table := session.Statement.RefTable table := session.statement.RefTable
cacher := session.Engine.getCacher2(table) cacher := session.engine.getCacher2(table)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
rows, err := session.DB().Query(newsql, args...) rows, err := session.NoCache().queryRows(newsql, args...)
if err != nil { if err != nil {
return err return err
} }
@ -314,7 +309,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for rows.Next() { for rows.Next() {
i++ i++
if i > 500 { if i > 500 {
session.Engine.logger.Debug("[cacheFind] ids length > 500, no cache") session.engine.logger.Debug("[cacheFind] ids length > 500, no cache")
return ErrCacheFailed return ErrCacheFailed
} }
var res = make([]string, len(table.PrimaryKeys)) var res = make([]string, len(table.PrimaryKeys))
@ -324,7 +319,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
pk[i], err = session.Engine.idTypeAssertion(col, res[i]) pk[i], err = session.engine.idTypeAssertion(col, res[i])
if err != nil { if err != nil {
return err return err
} }
@ -333,13 +328,13 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ids = append(ids, pk) ids = append(ids, pk)
} }
session.Engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, newsql, args) session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, newsql, args)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args) err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
session.Engine.logger.Debug("[cacheFind] cache hit sql:", newsql, args) session.engine.logger.Debug("[cacheFind] cache hit sql:", newsql, args)
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
@ -358,16 +353,16 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ides = append(ides, id) ides = append(ides, id)
ididxes[sid] = idx ididxes[sid] = idx
} else { } else {
session.Engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean) session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean)
pk := session.Engine.IdOf(bean) pk := session.engine.IdOf(bean)
xid, err := pk.ToString() xid, err := pk.ToString()
if err != nil { if err != nil {
return err return err
} }
if sid != xid { if sid != xid {
session.Engine.logger.Error("[cacheFind] error cache", xid, sid, bean) session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean)
return ErrCacheFailed return ErrCacheFailed
} }
temps[idx] = bean temps[idx] = bean
@ -375,9 +370,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
if len(ides) > 0 { if len(ides) > 0 {
newSession := session.Engine.NewSession()
defer newSession.Close()
slices := reflect.New(reflect.SliceOf(t)) slices := reflect.New(reflect.SliceOf(t))
beans := slices.Interface() beans := slices.Interface()
@ -387,18 +379,18 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ff = append(ff, ie[0]) ff = append(ff, ie[0])
} }
newSession.In("`"+table.PrimaryKeys[0]+"`", ff...) session.In("`"+table.PrimaryKeys[0]+"`", ff...)
} else { } else {
for _, ie := range ides { for _, ie := range ides {
cond := builder.NewCond() cond := builder.NewCond()
for i, name := range table.PrimaryKeys { for i, name := range table.PrimaryKeys {
cond = cond.And(builder.Eq{"`" + name + "`": ie[i]}) cond = cond.And(builder.Eq{"`" + name + "`": ie[i]})
} }
newSession.Or(cond) session.Or(cond)
} }
} }
err = newSession.NoCache().Find(beans) err = session.NoCache().find(beans)
if err != nil { if err != nil {
return err return err
} }
@ -409,7 +401,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
id, err := session.Engine.idOfV(rv) id, err := session.engine.idOfV(rv)
if err != nil { if err != nil {
return err return err
} }
@ -420,7 +412,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
bean := rv.Interface() bean := rv.Interface()
temps[ididxes[sid]] = bean temps[ididxes[sid]] = bean
session.Engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps) session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps)
cacher.PutBean(tableName, sid, bean) cacher.PutBean(tableName, sid, bean)
} }
} }
@ -428,7 +420,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for j := 0; j < len(temps); j++ { for j := 0; j < len(temps); j++ {
bean := temps[j] bean := temps[j]
if bean == nil { if bean == nil {
session.Engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps) session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps)
// return errors.New("cache error") // !nashtsai! no need to return error, but continue instead // return errors.New("cache error") // !nashtsai! no need to return error, but continue instead
continue continue
} }

View File

@ -15,18 +15,22 @@ import (
// Get retrieve one record from database, bean's non-empty fields // Get retrieve one record from database, bean's non-empty fields
// will be as conditions // will be as conditions
func (session *Session) Get(bean interface{}) (bool, error) { func (session *Session) Get(bean interface{}) (bool, error) {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
return session.get(bean)
}
func (session *Session) get(bean interface{}) (bool, error) {
beanValue := reflect.ValueOf(bean) beanValue := reflect.ValueOf(bean)
if beanValue.Kind() != reflect.Ptr { if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer") return false, errors.New("needs a pointer to a value")
} else if beanValue.Elem().Kind() == reflect.Ptr {
return false, errors.New("a pointer to a pointer is not allowed")
} }
if beanValue.Elem().Kind() == reflect.Struct { if beanValue.Elem().Kind() == reflect.Struct {
if err := session.Statement.setRefValue(beanValue.Elem()); err != nil { if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
return false, err return false, err
} }
} }
@ -35,23 +39,23 @@ func (session *Session) Get(bean interface{}) (bool, error) {
var args []interface{} var args []interface{}
var err error var err error
if session.Statement.RawSQL == "" { if session.statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
return false, ErrTableNotFound return false, ErrTableNotFound
} }
session.Statement.Limit(1) session.statement.Limit(1)
sqlStr, args, err = session.Statement.genGetSQL(bean) sqlStr, args, err = session.statement.genGetSQL(bean)
if err != nil { if err != nil {
return false, err return false, err
} }
} else { } else {
sqlStr = session.Statement.RawSQL sqlStr = session.statement.RawSQL
args = session.Statement.RawParams args = session.statement.RawParams
} }
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && if cacher := session.engine.getCacher2(session.statement.RefTable); cacher != nil &&
!session.Statement.unscoped { !session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return has, err return has, err
@ -63,50 +67,42 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
session.queryPreprocess(&sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
var rawRows *core.Rows
var err error
if session.IsAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.Tx.Query(sqlStr, args...)
}
if err != nil { if err != nil {
return false, err return false, err
} }
defer rows.Close()
defer rawRows.Close() if !rows.Next() {
if !rawRows.Next() {
return false, nil return false, nil
} }
switch beanKind { switch beanKind {
case reflect.Struct: case reflect.Struct:
fields, err := rawRows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
// WARN: Alougth rawRows return true, but get fields failed // WARN: Alougth rows return true, but get fields failed
return true, err return true, err
} }
dataStruct := rValue(bean) dataStruct := rValue(bean)
if err := session.Statement.setRefValue(dataStruct); err != nil { if err := session.statement.setRefValue(dataStruct); err != nil {
return false, err return false, err
} }
scanResults, err := session.row2Slice(rawRows, fields, len(fields), bean) scanResults, err := session.row2Slice(rows, fields, len(fields), bean)
if err != nil { if err != nil {
return false, err return false, err
} }
rawRows.Close() // close it before covert data
rows.Close()
_, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, session.Statement.RefTable) _, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, session.statement.RefTable)
case reflect.Slice: case reflect.Slice:
err = rawRows.ScanSlice(bean) err = rows.ScanSlice(bean)
case reflect.Map: case reflect.Map:
err = rawRows.ScanMap(bean) err = rows.ScanMap(bean)
default: default:
err = rawRows.Scan(bean) err = rows.Scan(bean)
} }
return true, err return true, err
@ -118,22 +114,22 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed return false, ErrCacheFailed
} }
for _, filter := range session.Engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
} }
newsql := session.Statement.convertIDSQL(sqlStr) newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {
return false, ErrCacheFailed return false, ErrCacheFailed
} }
cacher := session.Engine.getCacher2(session.Statement.RefTable) cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.Statement.TableName() tableName := session.statement.TableName()
session.Engine.logger.Debug("[cacheGet] find sql:", newsql, args) session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
table := session.Statement.RefTable table := session.statement.RefTable
if err != nil { if err != nil {
var res = make([]string, len(table.PrimaryKeys)) var res = make([]string, len(table.PrimaryKeys))
rows, err := session.DB().Query(newsql, args...) rows, err := session.NoCache().queryRows(newsql, args...)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -164,19 +160,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
ids = []core.PK{pk} ids = []core.PK{pk}
session.Engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args) err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil { if err != nil {
return false, err return false, err
} }
} else { } else {
session.Engine.logger.Debug("[cacheGet] cache hit sql:", newsql) session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql)
} }
if len(ids) > 0 { if len(ids) > 0 {
structValue := reflect.Indirect(reflect.ValueOf(bean)) structValue := reflect.Indirect(reflect.ValueOf(bean))
id := ids[0] id := ids[0]
session.Engine.logger.Debug("[cacheGet] get bean:", tableName, id) session.engine.logger.Debug("[cacheGet] get bean:", tableName, id)
sid, err := id.ToString() sid, err := id.ToString()
if err != nil { if err != nil {
return false, err return false, err
@ -189,10 +185,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return has, err return has, err
} }
session.Engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean) session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean)
cacher.PutBean(tableName, sid, cacheBean) cacher.PutBean(tableName, sid, cacheBean)
} else { } else {
session.Engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean) session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean)
has = true has = true
} }
structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))

View File

@ -71,15 +71,18 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "28", valuesString["age"]) assert.Equal(t, "28", valuesString["age"])
assert.Equal(t, "1.5", valuesString["money"]) assert.Equal(t, "1.5", valuesString["money"])
var valuesInter = make(map[string]interface{}) // for mymysql driver, interface{} will be []byte, so ignore it currently
has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) if testEngine.dialect.DriverName() != "mymysql" {
assert.NoError(t, err) var valuesInter = make(map[string]interface{})
assert.Equal(t, true, has) has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter)
assert.Equal(t, 5, len(valuesInter)) assert.NoError(t, err)
assert.EqualValues(t, 1, valuesInter["id"]) assert.Equal(t, true, has)
assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"])) assert.Equal(t, 5, len(valuesInter))
assert.EqualValues(t, 28, valuesInter["age"]) assert.EqualValues(t, 1, valuesInter["id"])
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"]))
assert.EqualValues(t, 28, valuesInter["age"])
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"]))
}
var valuesSliceString = make([]string, 5) var valuesSliceString = make([]string, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceString) has, err = testEngine.Table("get_var").Get(&valuesSliceString)
@ -171,3 +174,23 @@ func TestGetSlice(t *testing.T) {
assert.False(t, has) assert.False(t, has)
assert.Error(t, err) assert.Error(t, err)
} }
func TestGetError(t *testing.T) {
assert.NoError(t, prepareEngine())
type GetError struct {
Uid int `xorm:"pk autoincr"`
IsMan bool
}
assertSync(t, new(GetError))
var info = new(GetError)
has, err := testEngine.Get(&info)
assert.False(t, has)
assert.Error(t, err)
has, err = testEngine.Get(info)
assert.False(t, has)
assert.NoError(t, err)
}

View File

@ -19,17 +19,16 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
var affected int64 var affected int64
var err error var err error
if session.IsAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
defer session.resetStatement()
for _, bean := range beans { for _, bean := range beans {
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice { if sliceValue.Kind() == reflect.Slice {
size := sliceValue.Len() size := sliceValue.Len()
if size > 0 { if size > 0 {
if session.Engine.SupportInsertMany() { if session.engine.SupportInsertMany() {
cnt, err := session.innerInsertMulti(bean) cnt, err := session.innerInsertMulti(bean)
if err != nil { if err != nil {
return affected, err return affected, err
@ -67,15 +66,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice") return 0, errors.New("could not insert a empty slice")
} }
if err := session.Statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
return 0, err return 0, err
} }
if len(session.Statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
table := session.Statement.RefTable table := session.statement.RefTable
size := sliceValue.Len() size := sliceValue.Len()
var colNames []string var colNames []string
@ -116,18 +115,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.Statement.ColumnStr != "" { if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue continue
} }
} }
if session.Statement.OmitStr != "" { if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok { if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue continue
} }
} }
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name) val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val) args = append(args, val)
var colName = col.Name var colName = col.Name
@ -135,7 +134,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.Statement.checkVersion { } else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1) args = append(args, 1)
var colName = col.Name var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -171,18 +170,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.Statement.ColumnStr != "" { if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue continue
} }
} }
if session.Statement.OmitStr != "" { if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok { if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue continue
} }
} }
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.Engine.NowTime2(col.SQLType.Name) val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val) args = append(args, val)
var colName = col.Name var colName = col.Name
@ -190,7 +189,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.Statement.checkVersion { } else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1) args = append(args, 1)
var colName = col.Name var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -214,25 +213,25 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
var statement string var statement string
if session.Engine.dialect.DBType() == core.ORACLE { if session.engine.dialect.DBType() == core.ORACLE {
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
session.Engine.Quote(session.Statement.TableName()), session.engine.Quote(session.statement.TableName()),
session.Engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.Engine.QuoteStr()) session.engine.QuoteStr())
statement = fmt.Sprintf(sql, statement = fmt.Sprintf(sql,
session.Engine.Quote(session.Statement.TableName()), session.engine.Quote(session.statement.TableName()),
session.Engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.Engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colMultiPlaces, temp)) strings.Join(colMultiPlaces, temp))
} else { } else {
statement = fmt.Sprintf(sql, statement = fmt.Sprintf(sql,
session.Engine.Quote(session.Statement.TableName()), session.engine.Quote(session.statement.TableName()),
session.Engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.Engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colMultiPlaces, "),(")) strings.Join(colMultiPlaces, "),("))
} }
res, err := session.exec(statement, args...) res, err := session.exec(statement, args...)
@ -240,8 +239,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, err return 0, err
} }
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.statement.TableName())
} }
lenAfterClosures := len(session.afterClosures) lenAfterClosures := len(session.afterClosures)
@ -249,7 +248,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface() elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
// handle AfterInsertProcessor // handle AfterInsertProcessor
if session.IsAutoCommit { if session.isAutoCommit {
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
for _, closure := range session.afterClosures { for _, closure := range session.afterClosures {
closure(elemValue) closure(elemValue)
@ -280,8 +279,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
// InsertMulti insert multiple records // InsertMulti insert multiple records
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
@ -299,14 +297,14 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
} }
func (session *Session) innerInsert(bean interface{}) (int64, error) { func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.Statement.setRefValue(rValue(bean)); err != nil { if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err return 0, err
} }
if len(session.Statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
table := session.Statement.RefTable table := session.statement.RefTable
// handle BeforeInsertProcessor // handle BeforeInsertProcessor
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
@ -318,12 +316,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
processor.BeforeInsert() processor.BeforeInsert()
} }
// -- // --
colNames, args, err := genCols(session.Statement.RefTable, session, bean, false, false) colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
// insert expr columns, override if exists // insert expr columns, override if exists
exprColumns := session.Statement.getExpr() exprColumns := session.statement.getExpr()
exprColVals := make([]string, 0, len(exprColumns)) exprColVals := make([]string, 0, len(exprColumns))
for _, v := range exprColumns { for _, v := range exprColumns {
// remove the expr columns // remove the expr columns
@ -351,21 +349,21 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
var sqlStr string var sqlStr string
if len(colPlaces) > 0 { if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
session.Engine.Quote(session.Statement.TableName()), session.engine.Quote(session.statement.TableName()),
session.Engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colNames, session.Engine.Quote(", ")), strings.Join(colNames, session.engine.Quote(", ")),
session.Engine.QuoteStr(), session.engine.QuoteStr(),
colPlaces) colPlaces)
} else { } else {
if session.Engine.dialect.DBType() == core.MYSQL { if session.engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.Engine.Quote(session.Statement.TableName())) sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(session.statement.TableName()))
} else { } else {
sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.Engine.Quote(session.Statement.TableName())) sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(session.statement.TableName()))
} }
} }
handleAfterInsertProcessorFunc := func(bean interface{}) { handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.IsAutoCommit { if session.isAutoCommit {
for _, closure := range session.afterClosures { for _, closure := range session.afterClosures {
closure(bean) closure(bean)
} }
@ -394,22 +392,22 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should // for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself. // implemented it ourself.
if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.query("select seq_atable.currval from dual", args...) res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
handleAfterInsertProcessorFunc(bean) handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.statement.TableName())
} }
if table.Version != "" && session.Statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() { } else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) verValue.SetInt(1)
} }
@ -427,7 +425,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean) aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
} }
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -437,24 +435,24 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type())) aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil return 1, nil
} else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { } else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != "" //assert table.AutoIncrement != ""
sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement) sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
res, err := session.query(sqlStr, args...) res, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
handleAfterInsertProcessorFunc(bean) handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.statement.TableName())
} }
if table.Version != "" && session.Statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() { } else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) verValue.SetInt(1)
} }
@ -472,7 +470,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean) aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
} }
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -490,14 +488,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
defer handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.statement.TableName())
} }
if table.Version != "" && session.Statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() { } else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) verValue.SetInt(1)
} }
@ -515,7 +513,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue, err := table.AutoIncrColumn().ValueOf(bean) aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
} }
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@ -532,8 +530,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// The in parameter bean must a struct or a point to struct. The return // The in parameter bean must a struct or a point to struct. The return
// parameter is inserted and error // parameter is inserted and error
func (session *Session) InsertOne(bean interface{}) (int64, error) { func (session *Session) InsertOne(bean interface{}) (int64, error) {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
@ -541,15 +538,15 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
} }
func (session *Session) cacheInsert(tables ...string) error { func (session *Session) cacheInsert(tables ...string) error {
if session.Statement.RefTable == nil { if session.statement.RefTable == nil {
return ErrCacheFailed return ErrCacheFailed
} }
table := session.Statement.RefTable table := session.statement.RefTable
cacher := session.Engine.getCacher2(table) cacher := session.engine.getCacher2(table)
for _, t := range tables { for _, t := range tables {
session.Engine.logger.Debug("[cache] clear sql:", t) session.engine.logger.Debug("[cache] clear sql:", t)
cacher.ClearIds(t) cacher.ClearIds(t)
} }

View File

@ -19,6 +19,10 @@ func (session *Session) Rows(bean interface{}) (*Rows, error) {
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct // are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct // map[int64]*Struct
func (session *Session) Iterate(bean interface{}, fun IterFunc) error { func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
if session.isAutoClose {
defer session.Close()
}
rows, err := session.Rows(bean) rows, err := session.Rows(bean)
if err != nil { if err != nil {
return err return err

View File

@ -127,7 +127,7 @@ func TestIntId(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(bean.Id).Delete(&IntId{}) cnt, err = testEngine.ID(bean.Id).Delete(&IntId{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -202,7 +202,7 @@ func TestInt16Id(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(bean.Id).Delete(&Int16Id{}) cnt, err = testEngine.ID(bean.Id).Delete(&Int16Id{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -277,7 +277,7 @@ func TestInt32Id(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(bean.Id).Delete(&Int32Id{}) cnt, err = testEngine.ID(bean.Id).Delete(&Int32Id{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -366,7 +366,7 @@ func TestUintId(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(bean.Id).Delete(&UintId{}) cnt, err = testEngine.ID(bean.Id).Delete(&UintId{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -441,7 +441,7 @@ func TestUint16Id(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(bean.Id).Delete(&Uint16Id{}) cnt, err = testEngine.ID(bean.Id).Delete(&Uint16Id{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -516,7 +516,7 @@ func TestUint32Id(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(bean.Id).Delete(&Uint32Id{}) cnt, err = testEngine.ID(bean.Id).Delete(&Uint32Id{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -604,7 +604,7 @@ func TestUint64Id(t *testing.T) {
panic(errors.New("should be equal")) panic(errors.New("should be equal"))
} }
cnt, err = testEngine.Id(bean.Id).Delete(&Uint64Id{}) cnt, err = testEngine.ID(bean.Id).Delete(&Uint64Id{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -679,7 +679,7 @@ func TestStringPK(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(bean.Id).Delete(&StringPK{}) cnt, err = testEngine.ID(bean.Id).Delete(&StringPK{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -725,7 +725,7 @@ func TestCompositeKey(t *testing.T) {
} }
var compositeKeyVal CompositeKey var compositeKeyVal CompositeKey
has, err := testEngine.Id(core.PK{11, 22}).Get(&compositeKeyVal) has, err := testEngine.ID(core.PK{11, 22}).Get(&compositeKeyVal)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -734,7 +734,7 @@ func TestCompositeKey(t *testing.T) {
var compositeKeyVal2 CompositeKey var compositeKeyVal2 CompositeKey
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = testEngine.Id(&core.PK{11, 22}).Get(&compositeKeyVal2) has, err = testEngine.ID(&core.PK{11, 22}).Get(&compositeKeyVal2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -781,14 +781,14 @@ func TestCompositeKey(t *testing.T) {
} }
compositeKeyVal = CompositeKey{UpdateStr: "test1"} compositeKeyVal = CompositeKey{UpdateStr: "test1"}
cnt, err = testEngine.Id(core.PK{11, 22}).Update(&compositeKeyVal) cnt, err = testEngine.ID(core.PK{11, 22}).Update(&compositeKeyVal)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update CompositeKey{11, 22}")) t.Error(errors.New("can't update CompositeKey{11, 22}"))
} }
cnt, err = testEngine.Id(core.PK{11, 22}).Delete(&CompositeKey{}) cnt, err = testEngine.ID(core.PK{11, 22}).Delete(&CompositeKey{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -832,7 +832,7 @@ func TestCompositeKey2(t *testing.T) {
} }
var user User var user User
has, err := testEngine.Id(core.PK{"11", 22}).Get(&user) has, err := testEngine.ID(core.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -840,7 +840,7 @@ func TestCompositeKey2(t *testing.T) {
} }
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = testEngine.Id(&core.PK{"11", 22}).Get(&user) has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -848,14 +848,14 @@ func TestCompositeKey2(t *testing.T) {
} }
user = User{NickName: "test1"} user = User{NickName: "test1"}
cnt, err = testEngine.Id(core.PK{"11", 22}).Update(&user) cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}")) t.Error(errors.New("can't update User{11, 22}"))
} }
cnt, err = testEngine.Id(core.PK{"11", 22}).Delete(&User{}) cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&User{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -900,7 +900,7 @@ func TestCompositeKey3(t *testing.T) {
} }
var user UserPK2 var user UserPK2
has, err := testEngine.Id(core.PK{"11", 22}).Get(&user) has, err := testEngine.ID(core.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -908,7 +908,7 @@ func TestCompositeKey3(t *testing.T) {
} }
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = testEngine.Id(&core.PK{"11", 22}).Get(&user) has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -916,14 +916,14 @@ func TestCompositeKey3(t *testing.T) {
} }
user = UserPK2{NickName: "test1"} user = UserPK2{NickName: "test1"}
cnt, err = testEngine.Id(core.PK{"11", 22}).Update(&user) cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}")) t.Error(errors.New("can't update User{11, 22}"))
} }
cnt, err = testEngine.Id(core.PK{"11", 22}).Delete(&UserPK2{}) cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&UserPK2{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -1007,7 +1007,7 @@ func TestMyIntId(t *testing.T) {
panic(errors.New("should be equal")) panic(errors.New("should be equal"))
} }
cnt, err = testEngine.Id(bean.ID).Delete(&MyIntPK{}) cnt, err = testEngine.ID(bean.ID).Delete(&MyIntPK{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1095,7 +1095,7 @@ func TestMyStringId(t *testing.T) {
panic(errors.New("should be equal")) panic(errors.New("should be equal"))
} }
cnt, err = testEngine.Id(bean.ID).Delete(&MyStringPK{}) cnt, err = testEngine.ID(bean.ID).Delete(&MyStringPK{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)

View File

@ -6,77 +6,77 @@ package xorm
import ( import (
"database/sql" "database/sql"
"fmt"
"reflect" "reflect"
"strconv"
"time" "time"
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func (session *Session) query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
session.queryPreprocess(&sqlStr, paramStr...) for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
if session.IsAutoCommit {
return session.innerQuery2(sqlStr, paramStr...)
} }
return session.txQuery(session.Tx, sqlStr, paramStr...)
session.lastSQL = *sqlStr
session.lastSQLArgs = paramStr
} }
func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string][]byte, error) { func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) {
rows, err := tx.Query(sqlStr, params...) defer session.resetStatement()
if err != nil {
return nil, err session.queryPreprocess(&sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
} }
defer rows.Close()
return rows2maps(rows) if session.isAutoCommit {
} if session.prepareStmt {
// don't clear stmt since session will cache them
func (session *Session) innerQuery(sqlStr string, params ...interface{}) (*core.Stmt, *core.Rows, error) {
var callback func() (*core.Stmt, *core.Rows, error)
if session.prepareStmt {
callback = func() (*core.Stmt, *core.Rows, error) {
stmt, err := session.doPrepare(sqlStr) stmt, err := session.doPrepare(sqlStr)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
rows, err := stmt.Query(params...)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
} else {
callback = func() (*core.Stmt, *core.Rows, error) {
rows, err := session.DB().Query(sqlStr, params...)
if err != nil {
return nil, nil, err
}
return nil, rows, err
}
}
stmt, rows, err := session.Engine.logSQLQueryTime(sqlStr, params, callback)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { rows, err := stmt.Query(args...)
fields, err := rows.Columns() if err != nil {
if err != nil { return nil, err
return nil, err }
} return rows, nil
for rows.Next() { }
result, err := row2map(rows, fields)
rows, err := session.DB().Query(sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resultsSlice = append(resultsSlice, result) return rows, nil
} }
return resultsSlice, nil rows, err := session.tx.Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
return core.NewRow(session.queryRows(sqlStr, args...))
} }
func value2Bytes(rawValue *reflect.Value) (data []byte, err error) { func value2Bytes(rawValue *reflect.Value) (data []byte, err error) {
@ -104,7 +104,7 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore //if row is null then ignore
if rawValue.Interface() == nil { if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue) result[key] = []byte{}
continue continue
} }
@ -117,34 +117,13 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er
return result, nil return result, nil
} }
func (session *Session) innerQuery2(sqlStr string, params ...interface{}) ([]map[string][]byte, error) { func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
_, rows, err := session.innerQuery(sqlStr, params...)
if rows != nil {
defer rows.Close()
}
if err != nil {
return nil, err
}
return rows2maps(rows)
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
return session.query(sqlStr, paramStr...)
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for rows.Next() { for rows.Next() {
result, err := row2mapStr(rows, fields) result, err := row2map(rows, fields)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -154,122 +133,45 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error)
return resultsSlice, nil return resultsSlice, nil
} }
func reflect2value(rawValue *reflect.Value) (str string, err error) { func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
aa := reflect.TypeOf((*rawValue).Interface()) rows, err := session.queryRows(sqlStr, args...)
vv := reflect.ValueOf((*rawValue).Interface())
switch aa.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
case reflect.String:
str = vv.String()
case reflect.Array, reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
data := rawValue.Interface().([]byte)
str = string(data)
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(core.TimeType) {
str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
case reflect.Bool:
str = strconv.FormatBool(vv.Bool())
case reflect.Complex128, reflect.Complex64:
str = fmt.Sprintf("%v", vv.Complex())
/* TODO: unsupported types below
case reflect.Map:
case reflect.Ptr:
case reflect.Uintptr:
case reflect.UnsafePointer:
case reflect.Chan, reflect.Func, reflect.Interface:
*/
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
return
}
func value2String(rawValue *reflect.Value) (data string, err error) {
data, err = reflect2value(rawValue)
if err != nil {
return
}
return
}
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
result := make(map[string]string)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
if data, err := value2String(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string]string, error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
return rows2Strings(rows) return rows2maps(rows)
} }
func query2(db *core.DB, sqlStr string, params ...interface{}) ([]map[string]string, error) { func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
rows, err := db.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
// QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
defer session.resetStatement() defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
session.queryPreprocess(&sqlStr, args...) session.queryPreprocess(&sqlStr, args...)
if session.IsAutoCommit { if session.engine.showSQL {
return query2(session.DB(), sqlStr, args...) if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
if !session.isAutoCommit {
return session.tx.Exec(sqlStr, args...)
} }
return txQuery2(session.Tx, sqlStr, args...)
}
// Execute sql
func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) {
if session.prepareStmt { if session.prepareStmt {
stmt, err := session.doPrepare(sqlStr) stmt, err := session.doPrepare(sqlStr)
if err != nil { if err != nil {
@ -286,33 +188,9 @@ func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Resul
return session.DB().Exec(sqlStr, args...) return session.DB().Exec(sqlStr, args...)
} }
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
for _, filter := range session.Engine.dialect.Filters() {
// TODO: for table name, it's no need to RefTable
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
session.saveLastSQL(sqlStr, args...)
return session.Engine.logSQLExecutionTime(sqlStr, args, func() (sql.Result, error) {
if session.IsAutoCommit {
// FIXME: oci8 can not auto commit (github.com/mattn/go-oci8)
if session.Engine.dialect.DBType() == core.ORACLE {
session.Begin()
r, err := session.Tx.Exec(sqlStr, args...)
session.Commit()
return r, err
}
return session.innerExec(sqlStr, args...)
}
return session.Tx.Exec(sqlStr, args...)
})
}
// Exec raw sql // Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }

View File

@ -7,42 +7,10 @@ package xorm
import ( import (
"strconv" "strconv"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestQueryString(t *testing.T) {
assert.NoError(t, prepareEngine())
type GetVar struct {
Id int64 `xorm:"autoincr pk"`
Msg string `xorm:"varchar(255)"`
Age int
Money float32
Created time.Time `xorm:"created"`
}
assert.NoError(t, testEngine.Sync2(new(GetVar)))
var data = GetVar{
Msg: "hi",
Age: 28,
Money: 1.5,
}
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)
records, err := testEngine.QueryString("select * from get_var")
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
assert.Equal(t, "1", records[0]["id"])
assert.Equal(t, "hi", records[0]["msg"])
assert.Equal(t, "28", records[0]["age"])
assert.Equal(t, "1.5", records[0]["money"])
}
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())

View File

@ -16,42 +16,50 @@ import (
// Ping test if database is ok // Ping test if database is ok
func (session *Session) Ping() error { func (session *Session) Ping() error {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().Ping() return session.DB().Ping()
} }
// CreateTable create a table according a bean // CreateTable create a table according a bean
func (session *Session) CreateTable(bean interface{}) error { func (session *Session) CreateTable(bean interface{}) error {
v := rValue(bean) if session.isAutoClose {
if err := session.Statement.setRefValue(v); err != nil {
return err
}
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
return session.createOneTable() return session.createTable(bean)
}
func (session *Session) createTable(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqlStr := session.statement.genCreateTableSQL()
_, err := session.exec(sqlStr)
return err
} }
// CreateIndexes create indexes // CreateIndexes create indexes
func (session *Session) CreateIndexes(bean interface{}) error { func (session *Session) CreateIndexes(bean interface{}) error {
v := rValue(bean) if session.isAutoClose {
if err := session.Statement.setRefValue(v); err != nil {
return err
}
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
sqls := session.Statement.genIndexSQL() return session.createIndexes(bean)
}
func (session *Session) createIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqls := session.statement.genIndexSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
@ -63,17 +71,19 @@ func (session *Session) CreateIndexes(bean interface{}) error {
// CreateUniques create uniques // CreateUniques create uniques
func (session *Session) CreateUniques(bean interface{}) error { func (session *Session) CreateUniques(bean interface{}) error {
if session.isAutoClose {
defer session.Close()
}
return session.createUniques(bean)
}
func (session *Session) createUniques(bean interface{}) error {
v := rValue(bean) v := rValue(bean)
if err := session.Statement.setRefValue(v); err != nil { if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
defer session.resetStatement() sqls := session.statement.genUniqueSQL()
if session.IsAutoClose {
defer session.Close()
}
sqls := session.Statement.genUniqueSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
@ -83,25 +93,22 @@ func (session *Session) CreateUniques(bean interface{}) error {
return nil return nil
} }
func (session *Session) createOneTable() error {
sqlStr := session.Statement.genCreateTableSQL()
_, err := session.exec(sqlStr)
return err
}
// DropIndexes drop indexes // DropIndexes drop indexes
func (session *Session) DropIndexes(bean interface{}) error { func (session *Session) DropIndexes(bean interface{}) error {
v := rValue(bean) if session.isAutoClose {
if err := session.Statement.setRefValue(v); err != nil {
return err
}
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
sqls := session.Statement.genDelIndexSQL() return session.dropIndexes(bean)
}
func (session *Session) dropIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err
}
sqls := session.statement.genDelIndexSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
@ -113,15 +120,23 @@ func (session *Session) DropIndexes(bean interface{}) error {
// DropTable drop table will drop table if exist, if drop failed, it will return error // DropTable drop table will drop table if exist, if drop failed, it will return error
func (session *Session) DropTable(beanOrTableName interface{}) error { func (session *Session) DropTable(beanOrTableName interface{}) error {
tableName, err := session.Engine.tableName(beanOrTableName) if session.isAutoClose {
defer session.Close()
}
return session.dropTable(beanOrTableName)
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil { if err != nil {
return err return err
} }
var needDrop = true var needDrop = true
if !session.Engine.dialect.SupportDropIfExists() { if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.Engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...) results, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return err return err
} }
@ -129,7 +144,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
} }
if needDrop { if needDrop {
sqlStr := session.Engine.Dialect().DropTableSql(tableName) sqlStr := session.engine.Dialect().DropTableSql(tableName)
_, err = session.exec(sqlStr) _, err = session.exec(sqlStr)
return err return err
} }
@ -138,7 +153,11 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
// IsTableExist if a table is exist // IsTableExist if a table is exist
func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) { func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
tableName, err := session.Engine.tableName(beanOrTableName) if session.isAutoClose {
defer session.Close()
}
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -147,12 +166,8 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
} }
func (session *Session) isTableExist(tableName string) (bool, error) { func (session *Session) isTableExist(tableName string) (bool, error) {
defer session.resetStatement() sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
if session.IsAutoClose { results, err := session.queryBytes(sqlStr, args...)
defer session.Close()
}
sqlStr, args := session.Engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...)
return len(results) > 0, err return len(results) > 0, err
} }
@ -162,6 +177,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
t := v.Type() t := v.Type()
if t.Kind() == reflect.String { if t.Kind() == reflect.String {
if session.isAutoClose {
defer session.Close()
}
return session.isTableEmpty(bean.(string)) return session.isTableEmpty(bean.(string))
} else if t.Kind() == reflect.Struct { } else if t.Kind() == reflect.Struct {
rows, err := session.Count(bean) rows, err := session.Count(bean)
@ -171,15 +189,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
} }
func (session *Session) isTableEmpty(tableName string) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var total int64 var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName)) sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
err := session.DB().QueryRow(sqlStr).Scan(&total) err := session.queryRow(sqlStr).Scan(&total)
session.saveLastSQL(sqlStr)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
@ -192,12 +204,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
// find if index is exist according cols // find if index is exist according cols
func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
defer session.resetStatement() indexes, err := session.engine.dialect.GetIndexes(tableName)
if session.IsAutoClose {
defer session.Close()
}
indexes, err := session.Engine.dialect.GetIndexes(tableName)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -214,43 +221,34 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
} }
func (session *Session) addColumn(colName string) error { func (session *Session) addColumn(colName string) error {
defer session.resetStatement() col := session.statement.RefTable.GetColumn(colName)
if session.IsAutoClose { sql, args := session.statement.genAddColumnStr(col)
defer session.Close()
}
col := session.Statement.RefTable.GetColumn(colName)
sql, args := session.Statement.genAddColumnStr(col)
_, err := session.exec(sql, args...) _, err := session.exec(sql, args...)
return err return err
} }
func (session *Session) addIndex(tableName, idxName string) error { func (session *Session) addIndex(tableName, idxName string) error {
defer session.resetStatement() index := session.statement.RefTable.Indexes[idxName]
if session.IsAutoClose { sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
defer session.Close()
}
index := session.Statement.RefTable.Indexes[idxName]
sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index)
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
func (session *Session) addUnique(tableName, uqeName string) error { func (session *Session) addUnique(tableName, uqeName string) error {
defer session.resetStatement() index := session.statement.RefTable.Indexes[uqeName]
if session.IsAutoClose { sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
defer session.Close()
}
index := session.Statement.RefTable.Indexes[uqeName]
sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index)
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
// Sync2 synchronize structs to database tables // Sync2 synchronize structs to database tables
func (session *Session) Sync2(beans ...interface{}) error { func (session *Session) Sync2(beans ...interface{}) error {
engine := session.Engine engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.DBMetas() tables, err := engine.DBMetas()
if err != nil { if err != nil {
@ -277,17 +275,17 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
if oriTable == nil { if oriTable == nil {
err = session.StoreEngine(session.Statement.StoreEngine).CreateTable(bean) err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil { if err != nil {
return err return err
} }
err = session.CreateUniques(bean) err = session.createUniques(bean)
if err != nil { if err != nil {
return err return err
} }
err = session.CreateIndexes(bean) err = session.createIndexes(bean)
if err != nil { if err != nil {
return err return err
} }
@ -312,7 +310,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
engine.dialect.DBType() == core.POSTGRES { engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n", engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType) tbName, col.Name, curType, expectedType)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} else { } else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType) tbName, col.Name, curType, expectedType)
@ -322,7 +320,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length) tbName, col.Name, oriCol.Length, col.Length)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} }
} }
} else { } else {
@ -336,7 +334,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length) tbName, col.Name, oriCol.Length, col.Length)
_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} }
} }
} }
@ -349,10 +347,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbName, col.Name, oriCol.Nullable, col.Nullable) tbName, col.Name, oriCol.Nullable, col.Nullable)
} }
} else { } else {
session := engine.NewSession() session.statement.RefTable = table
session.Statement.RefTable = table session.statement.tableName = tbName
session.Statement.tableName = tbName
defer session.Close()
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
} }
if err != nil { if err != nil {
@ -376,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil { if oriIndex != nil {
if oriIndex.Type != index.Type { if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex) sql := engine.dialect.DropIndexSql(tbName, oriIndex)
_, err = engine.Exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
} }
@ -392,7 +388,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok { if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2) sql := engine.dialect.DropIndexSql(tbName, index2)
_, err = engine.Exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
} }
@ -401,16 +397,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
session := engine.NewSession() session.statement.RefTable = table
session.Statement.RefTable = table session.statement.tableName = tbName
session.Statement.tableName = tbName
defer session.Close()
err = session.addUnique(tbName, name) err = session.addUnique(tbName, name)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
session := engine.NewSession() session.statement.RefTable = table
session.Statement.RefTable = table session.statement.tableName = tbName
session.Statement.tableName = tbName
defer session.Close()
err = session.addIndex(tbName, name) err = session.addIndex(tbName, name)
} }
if err != nil { if err != nil {

View File

@ -1,187 +0,0 @@
// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import "database/sql"
// Count counts the records. bean's non-empty fields
// are conditions.
func (session *Session) Count(bean ...interface{}) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if session.Statement.RawSQL == "" {
if len(bean) == 0 {
return 0, ErrTableNotFound
}
sqlStr, args, err = session.Statement.genCountSQL(bean[0])
if err != nil {
return 0, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var total int64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&total)
}
if err == sql.ErrNoRows || err == nil {
return total, nil
}
return 0, err
}
// Sum call sum some column. bean's non-empty fields are conditions.
func (session *Session) Sum(bean interface{}, columnName string) (float64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
if err != nil {
return 0, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res float64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return 0, err
}
// SumInt call sum some column. bean's non-empty fields are conditions.
func (session *Session) SumInt(bean interface{}, columnName string) (int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
if err != nil {
return 0, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res int64
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).Scan(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return 0, err
}
// Sums call sum some columns. bean's non-empty fields are conditions.
func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
if err != nil {
return nil, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res = make([]float64, len(columnNames), len(columnNames))
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return nil, err
}
// SumsInt sum specify columns and return as []int64 instead of []float64
func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
defer session.resetStatement()
if session.IsAutoClose {
defer session.Close()
}
var sqlStr string
var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 {
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
if err != nil {
return nil, err
}
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var res = make([]int64, len(columnNames), len(columnNames))
if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
} else {
err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
}
if err == sql.ErrNoRows || err == nil {
return res, nil
}
return nil, err
}

View File

@ -1,152 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"strconv"
"testing"
"github.com/go-xorm/builder"
"github.com/stretchr/testify/assert"
)
func isFloatEq(i, j float64, precision int) bool {
return fmt.Sprintf("%."+strconv.Itoa(precision)+"f", i) == fmt.Sprintf("%."+strconv.Itoa(precision)+"f", j)
}
func TestSum(t *testing.T) {
assert.NoError(t, prepareEngine())
type SumStruct struct {
Int int
Float float32
}
var (
cases = []SumStruct{
{1, 6.2},
{2, 5.3},
{92, -0.2},
}
)
var i int
var f float32
for _, v := range cases {
i += v.Int
f += v.Float
}
assert.NoError(t, testEngine.Sync2(new(SumStruct)))
cnt, err := testEngine.Insert(cases)
assert.NoError(t, err)
assert.EqualValues(t, 3, cnt)
colInt := testEngine.ColumnMapper.Obj2Table("Int")
colFloat := testEngine.ColumnMapper.Obj2Table("Float")
sumInt, err := testEngine.Sum(new(SumStruct), colInt)
assert.NoError(t, err)
assert.EqualValues(t, int(sumInt), i)
sumFloat, err := testEngine.Sum(new(SumStruct), colFloat)
assert.NoError(t, err)
assert.Condition(t, func() bool {
return isFloatEq(sumFloat, float64(f), 2)
})
sums, err := testEngine.Sums(new(SumStruct), colInt, colFloat)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(sums))
assert.EqualValues(t, i, int(sums[0]))
assert.Condition(t, func() bool {
return isFloatEq(sums[1], float64(f), 2)
})
sumsInt, err := testEngine.SumsInt(new(SumStruct), colInt)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(sumsInt))
assert.EqualValues(t, i, int(sumsInt[0]))
}
func TestSumCustomColumn(t *testing.T) {
assert.NoError(t, prepareEngine())
type SumStruct struct {
Int int
Float float32
}
var (
cases = []SumStruct{
{1, 6.2},
{2, 5.3},
{92, -0.2},
}
)
assert.NoError(t, testEngine.Sync2(new(SumStruct)))
cnt, err := testEngine.Insert(cases)
assert.NoError(t, err)
assert.EqualValues(t, 3, cnt)
sumInt, err := testEngine.Sum(new(SumStruct),
"CASE WHEN `int` <= 2 THEN `int` ELSE 0 END")
assert.NoError(t, err)
assert.EqualValues(t, 3, int(sumInt))
}
func TestCount(t *testing.T) {
assert.NoError(t, prepareEngine())
type UserinfoCount struct {
Departname string
}
assert.NoError(t, testEngine.Sync2(new(UserinfoCount)))
colName := testEngine.ColumnMapper.Obj2Table("Departname")
var cond builder.Cond = builder.Eq{
"`" + colName + "`": "dev",
}
total, err := testEngine.Where(cond).Count(new(UserinfoCount))
assert.NoError(t, err)
assert.EqualValues(t, 0, total)
cnt, err := testEngine.Insert(&UserinfoCount{
Departname: "dev",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
total, err = testEngine.Where(cond).Count(new(UserinfoCount))
assert.NoError(t, err)
assert.EqualValues(t, 1, total)
}
func TestSQLCount(t *testing.T) {
assert.NoError(t, prepareEngine())
type UserinfoCount2 struct {
Id int64
Departname string
}
type UserinfoBooks struct {
Id int64
Pid int64
IsOpen bool
}
assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2").
Count()
assert.NoError(t, err)
assert.EqualValues(t, 0, total)
}

View File

@ -6,14 +6,14 @@ package xorm
// Begin a transaction // Begin a transaction
func (session *Session) Begin() error { func (session *Session) Begin() error {
if session.IsAutoCommit { if session.isAutoCommit {
tx, err := session.DB().Begin() tx, err := session.DB().Begin()
if err != nil { if err != nil {
return err return err
} }
session.IsAutoCommit = false session.isAutoCommit = false
session.IsCommitedOrRollbacked = false session.isCommitedOrRollbacked = false
session.Tx = tx session.tx = tx
session.saveLastSQL("BEGIN TRANSACTION") session.saveLastSQL("BEGIN TRANSACTION")
} }
return nil return nil
@ -21,25 +21,23 @@ func (session *Session) Begin() error {
// Rollback When using transaction, you can rollback if any error // Rollback When using transaction, you can rollback if any error
func (session *Session) Rollback() error { func (session *Session) Rollback() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL(session.Engine.dialect.RollBackStr()) session.saveLastSQL(session.engine.dialect.RollBackStr())
session.IsCommitedOrRollbacked = true session.isCommitedOrRollbacked = true
return session.Tx.Rollback() return session.tx.Rollback()
} }
return nil return nil
} }
// Commit When using transaction, Commit will commit all operations. // Commit When using transaction, Commit will commit all operations.
func (session *Session) Commit() error { func (session *Session) Commit() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL("COMMIT") session.saveLastSQL("COMMIT")
session.IsCommitedOrRollbacked = true session.isCommitedOrRollbacked = true
var err error var err error
if err = session.Tx.Commit(); err == nil { if err = session.tx.Commit(); err == nil {
// handle processors after tx committed // handle processors after tx committed
closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) {
if closuresPtr != nil { if closuresPtr != nil {
for _, closure := range *closuresPtr { for _, closure := range *closuresPtr {
closure(bean) closure(bean)

View File

@ -16,19 +16,19 @@ import (
) )
func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil || if session.statement.RefTable == nil ||
session.Tx != nil { session.tx != nil {
return ErrCacheFailed return ErrCacheFailed
} }
oldhead, newsql := session.Statement.convertUpdateSQL(sqlStr) oldhead, newsql := session.statement.convertUpdateSQL(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.Engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(newsql, session.Engine.dialect, session.Statement.RefTable) newsql = filter.Do(newsql, session.engine.dialect, session.statement.RefTable)
} }
session.Engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
var nStart int var nStart int
if len(args) > 0 { if len(args) > 0 {
@ -39,13 +39,13 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
nStart = strings.Count(oldhead, "$") nStart = strings.Count(oldhead, "$")
} }
} }
table := session.Statement.RefTable table := session.statement.RefTable
cacher := session.Engine.getCacher2(table) cacher := session.engine.getCacher2(table)
tableName := session.Statement.TableName() tableName := session.statement.TableName()
session.Engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil { if err != nil {
rows, err := session.DB().Query(newsql, args[nStart:]...) rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
if err != nil { if err != nil {
return err return err
} }
@ -75,9 +75,9 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
ids = append(ids, pk) ids = append(ids, pk)
} }
session.Engine.logger.Debug("[cacheUpdate] find updated id", ids) session.engine.logger.Debug("[cacheUpdate] find updated id", ids)
} /*else { } /*else {
session.Engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
cacher.DelIds(tableName, genSqlKey(newsql, args)) cacher.DelIds(tableName, genSqlKey(newsql, args))
}*/ }*/
@ -103,36 +103,36 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
colName := sps2[len(sps2)-1] colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") { if strings.Contains(colName, "`") {
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
} else if strings.Contains(colName, session.Engine.QuoteStr()) { } else if strings.Contains(colName, session.engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
} else { } else {
session.Engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed return ErrCacheFailed
} }
if col := table.GetColumn(colName); col != nil { if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean) fieldValue, err := col.ValueOf(bean)
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.engine.logger.Error(err)
} else { } else {
session.Engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.Statement.checkVersion { if col.IsVersion && session.statement.checkVersion {
fieldValue.SetInt(fieldValue.Int() + 1) fieldValue.SetInt(fieldValue.Int() + 1)
} else { } else {
fieldValue.Set(reflect.ValueOf(args[idx])) fieldValue.Set(reflect.ValueOf(args[idx]))
} }
} }
} else { } else {
session.Engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's", session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
colName, table.Name) colName, table.Name)
} }
} }
session.Engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean) session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
cacher.PutBean(tableName, sid, bean) cacher.PutBean(tableName, sid, bean)
} }
} }
session.Engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName) session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
return nil return nil
} }
@ -144,8 +144,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
// You should call UseBool if you have bool to use. // You should call UseBool if you have bool to use.
// 2.float32 & float64 may be not inexact as conditions // 2.float32 & float64 may be not inexact as conditions
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
defer session.resetStatement() if session.isAutoClose {
if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
@ -169,21 +168,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct var isStruct = t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.Statement.setRefValue(v); err != nil { if err := session.statement.setRefValue(v); err != nil {
return 0, err return 0, err
} }
if len(session.Statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
if session.Statement.ColumnStr == "" { if session.statement.ColumnStr == "" {
colNames, args = buildUpdates(session.Engine, session.Statement.RefTable, bean, false, false, colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false,
false, false, session.Statement.allUseBool, session.Statement.useAllCols, false, false, session.statement.allUseBool, session.statement.useAllCols,
session.Statement.mustColumnMap, session.Statement.nullableMap, session.statement.mustColumnMap, session.statement.nullableMap,
session.Statement.columnMap, true, session.Statement.unscoped) session.statement.columnMap, true, session.statement.unscoped)
} else { } else {
colNames, args, err = genCols(session.Statement.RefTable, session, bean, true, true) colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -194,19 +193,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
bValue := reflect.Indirect(reflect.ValueOf(bean)) bValue := reflect.Indirect(reflect.ValueOf(bean))
for _, v := range bValue.MapKeys() { for _, v := range bValue.MapKeys() {
colNames = append(colNames, session.Engine.Quote(v.String())+" = ?") colNames = append(colNames, session.engine.Quote(v.String())+" = ?")
args = append(args, bValue.MapIndex(v).Interface()) args = append(args, bValue.MapIndex(v).Interface())
} }
} else { } else {
return 0, ErrParamsType return 0, ErrParamsType
} }
table := session.Statement.RefTable table := session.statement.RefTable
if session.Statement.UseAutoTime && table != nil && table.Updated != "" { if session.statement.UseAutoTime && table != nil && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn() col := table.UpdatedColumn()
val, t := session.Engine.NowTime2(col.SQLType.Name) val, t := session.engine.NowTime2(col.SQLType.Name)
args = append(args, val) args = append(args, val)
var colName = col.Name var colName = col.Name
@ -219,45 +218,44 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
//for update action to like "column = column + ?" //for update action to like "column = column + ?"
incColumns := session.Statement.getInc() incColumns := session.statement.getInc()
for _, v := range incColumns { for _, v := range incColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" + ?") colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
args = append(args, v.arg) args = append(args, v.arg)
} }
//for update action to like "column = column - ?" //for update action to like "column = column - ?"
decColumns := session.Statement.getDec() decColumns := session.statement.getDec()
for _, v := range decColumns { for _, v := range decColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" - ?") colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
args = append(args, v.arg) args = append(args, v.arg)
} }
//for update action to like "column = expression" //for update action to like "column = expression"
exprColumns := session.Statement.getExpr() exprColumns := session.statement.getExpr()
for _, v := range exprColumns { for _, v := range exprColumns {
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr) colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
} }
if err = session.Statement.processIDParam(); err != nil { if err = session.statement.processIDParam(); err != nil {
return 0, err return 0, err
} }
var autoCond builder.Cond var autoCond builder.Cond
if !session.Statement.noAutoCondition && len(condiBean) > 0 { if !session.statement.noAutoCondition && len(condiBean) > 0 {
var err error var err error
autoCond, err = session.Statement.buildConds(session.Statement.RefTable, condiBean[0], true, true, false, true, false) autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
} }
st := session.Statement st := &session.statement
defer session.resetStatement()
var sqlStr string var sqlStr string
var condArgs []interface{} var condArgs []interface{}
var condSQL string var condSQL string
cond := session.Statement.cond.And(autoCond) cond := session.statement.cond.And(autoCond)
var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion) var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion)
var verValue *reflect.Value var verValue *reflect.Value
if doIncVer { if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean) verValue, err = table.VersionColumn().ValueOf(bean)
@ -265,8 +263,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err return 0, err
} }
cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()}) cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1") colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
} }
condSQL, condArgs, err = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
@ -290,7 +288,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} else if st.Engine.dialect.DBType() == core.SQLITE { } else if st.Engine.dialect.DBType() == core.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) session.engine.Quote(session.statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil { if err != nil {
return 0, err return 0, err
@ -301,7 +299,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} else if st.Engine.dialect.DBType() == core.POSTGRES { } else if st.Engine.dialect.DBType() == core.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) session.engine.Quote(session.statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil { if err != nil {
return 0, err return 0, err
@ -315,7 +313,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table != nil && len(table.PrimaryKeys) == 1 { table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...) session.engine.Quote(session.statement.TableName()), condSQL), condArgs...)
condSQL, condArgs, err = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil { if err != nil {
@ -330,9 +328,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
if len(colNames) <= 0 {
return 0, errors.New("No content found to be updated")
}
sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v", sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v",
top, top,
session.Engine.Quote(session.Statement.TableName()), session.engine.Quote(session.statement.TableName()),
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
condSQL) condSQL)
@ -346,19 +348,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
if table != nil { if table != nil {
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
cacher.ClearIds(session.Statement.TableName()) cacher.ClearIds(session.statement.TableName())
cacher.ClearBeans(session.Statement.TableName()) cacher.ClearBeans(session.statement.TableName())
} }
} }
// handle after update processors // handle after update processors
if session.IsAutoCommit { if session.isAutoCommit {
for _, closure := range session.afterClosures { for _, closure := range session.afterClosures {
closure(bean) closure(bean)
} }
if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
session.Engine.logger.Debug("[event]", session.Statement.TableName(), " has after update processor") session.engine.logger.Debug("[event]", session.statement.TableName(), " has after update processor")
processor.AfterUpdate() processor.AfterUpdate()
} }
} else { } else {

View File

@ -298,7 +298,7 @@ func TestUpdate1(t *testing.T) {
// update by id // update by id
user := Userinfo{Username: "xxx", Height: 1.2} user := Userinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.Id(ori.Uid).Update(&user) cnt, err := testEngine.ID(ori.Uid).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -311,7 +311,7 @@ func TestUpdate1(t *testing.T) {
} }
condi := Condi{"username": "zzz", "departname": ""} condi := Condi{"username": "zzz", "departname": ""}
cnt, err = testEngine.Table(&user).Id(ori.Uid).Update(&condi) cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -351,7 +351,7 @@ func TestUpdate1(t *testing.T) {
} }
userID := user.Uid userID := user.Uid
has, err := testEngine.Id(userID). has, err := testEngine.ID(userID).
And("username = ?", user.Username). And("username = ?", user.Username).
And("height = ?", user.Height). And("height = ?", user.Height).
And("departname = ?", ""). And("departname = ?", "").
@ -369,7 +369,7 @@ func TestUpdate1(t *testing.T) {
} }
updatedUser := &Userinfo{Username: "null data"} updatedUser := &Userinfo{Username: "null data"}
cnt, err = testEngine.Id(userID). cnt, err = testEngine.ID(userID).
Nullable("height", "departname", "is_man", "created"). Nullable("height", "departname", "is_man", "created").
Update(updatedUser) Update(updatedUser)
if err != nil { if err != nil {
@ -382,7 +382,7 @@ func TestUpdate1(t *testing.T) {
panic(err) panic(err)
} }
has, err = testEngine.Id(userID). has, err = testEngine.ID(userID).
And("username = ?", updatedUser.Username). And("username = ?", updatedUser.Username).
And("height IS NULL"). And("height IS NULL").
And("departname IS NULL"). And("departname IS NULL").
@ -400,7 +400,7 @@ func TestUpdate1(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(userID).Delete(&Userinfo{}) cnt, err = testEngine.ID(userID).Delete(&Userinfo{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -445,7 +445,7 @@ func TestUpdate1(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(a.Id).Update(&Article{Name: "6"}) cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -474,14 +474,14 @@ func TestUpdate1(t *testing.T) {
} }
col2 := &UpdateAllCols{col1.Id, true, "", nil} col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.Id(col2.Id).AllCols().Update(col2) _, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
col3 := &UpdateAllCols{} col3 := &UpdateAllCols{}
has, err = testEngine.Id(col2.Id).Get(col3) has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -519,14 +519,14 @@ func TestUpdate1(t *testing.T) {
col2 := &UpdateMustCols{col1.Id, true, ""} col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.ColumnMapper.Obj2Table("Bool") boolStr := testEngine.ColumnMapper.Obj2Table("Bool")
stringStr := testEngine.ColumnMapper.Obj2Table("String") stringStr := testEngine.ColumnMapper.Obj2Table("String")
_, err = testEngine.Id(col2.Id).MustCols(boolStr, stringStr).Update(col2) _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
col3 := &UpdateMustCols{} col3 := &UpdateMustCols{}
has, err := testEngine.Id(col2.Id).Get(col3) has, err := testEngine.ID(col2.Id).Get(col3)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -561,17 +561,27 @@ func TestUpdateIncrDecr(t *testing.T) {
colName := testEngine.ColumnMapper.Obj2Table("Cnt") colName := testEngine.ColumnMapper.Obj2Table("Cnt")
cnt, err := testEngine.Id(col1.Id).Incr(colName).Update(col1) cnt, err := testEngine.ID(col1.Id).Incr(colName).Update(col1)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
newCol := new(UpdateIncr) newCol := new(UpdateIncr)
has, err := testEngine.Id(col1.Id).Get(newCol) has, err := testEngine.ID(col1.Id).Get(newCol)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, 1, newCol.Cnt) assert.EqualValues(t, 1, newCol.Cnt)
cnt, err = testEngine.Id(col1.Id).Cols(colName).Incr(colName).Update(col1) cnt, err = testEngine.ID(col1.Id).Decr(colName).Update(col1)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
newCol = new(UpdateIncr)
has, err = testEngine.ID(col1.Id).Get(newCol)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, newCol.Cnt)
cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
} }
@ -616,12 +626,12 @@ func TestUpdateUpdated(t *testing.T) {
} }
ci := &UpdatedUpdate{} ci := &UpdatedUpdate{}
_, err = testEngine.Id(1).Update(ci) _, err = testEngine.ID(1).Update(ci)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
has, err := testEngine.Id(1).Get(di) has, err := testEngine.ID(1).Get(di)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -644,11 +654,11 @@ func TestUpdateUpdated(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ci2 := &UpdatedUpdate2{} ci2 := &UpdatedUpdate2{}
_, err = testEngine.Id(1).Update(ci2) _, err = testEngine.ID(1).Update(ci2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
has, err = testEngine.Id(1).Get(di2) has, err = testEngine.ID(1).Get(di2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -671,12 +681,12 @@ func TestUpdateUpdated(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ci3 := &UpdatedUpdate3{} ci3 := &UpdatedUpdate3{}
_, err = testEngine.Id(1).Update(ci3) _, err = testEngine.ID(1).Update(ci3)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
has, err = testEngine.Id(1).Get(di3) has, err = testEngine.ID(1).Get(di3)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -700,12 +710,12 @@ func TestUpdateUpdated(t *testing.T) {
} }
ci4 := &UpdatedUpdate4{} ci4 := &UpdatedUpdate4{}
_, err = testEngine.Id(1).Update(ci4) _, err = testEngine.ID(1).Update(ci4)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
has, err = testEngine.Id(1).Get(di4) has, err = testEngine.ID(1).Get(di4)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -728,12 +738,12 @@ func TestUpdateUpdated(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ci5 := &UpdatedUpdate5{} ci5 := &UpdatedUpdate5{}
_, err = testEngine.Id(1).Update(ci5) _, err = testEngine.ID(1).Update(ci5)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
has, err = testEngine.Id(1).Get(di5) has, err = testEngine.ID(1).Get(di5)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -786,7 +796,7 @@ func TestUpdateSameMapper(t *testing.T) {
} }
// update by id // update by id
user := Userinfo{Username: "xxx", Height: 1.2} user := Userinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.Id(ori.Uid).Update(&user) cnt, err := testEngine.ID(ori.Uid).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -799,7 +809,7 @@ func TestUpdateSameMapper(t *testing.T) {
} }
condi := Condi{"Username": "zzz", "Departname": ""} condi := Condi{"Username": "zzz", "Departname": ""}
cnt, err = testEngine.Table(&user).Id(ori.Uid).Update(&condi) cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -864,7 +874,7 @@ func TestUpdateSameMapper(t *testing.T) {
panic(err) panic(err)
} }
cnt, err = testEngine.Id(a.Id).Update(&Article{Name: "6"}) cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -891,14 +901,14 @@ func TestUpdateSameMapper(t *testing.T) {
} }
col2 := &UpdateAllCols{col1.Id, true, "", nil} col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.Id(col2.Id).AllCols().Update(col2) _, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
col3 := &UpdateAllCols{} col3 := &UpdateAllCols{}
has, err = testEngine.Id(col2.Id).Get(col3) has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -935,14 +945,14 @@ func TestUpdateSameMapper(t *testing.T) {
col2 := &UpdateMustCols{col1.Id, true, ""} col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.ColumnMapper.Obj2Table("Bool") boolStr := testEngine.ColumnMapper.Obj2Table("Bool")
stringStr := testEngine.ColumnMapper.Obj2Table("String") stringStr := testEngine.ColumnMapper.Obj2Table("String")
_, err = testEngine.Id(col2.Id).MustCols(boolStr, stringStr).Update(col2) _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
col3 := &UpdateMustCols{} col3 := &UpdateMustCols{}
has, err := testEngine.Id(col2.Id).Get(col3) has, err := testEngine.ID(col2.Id).Get(col3)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -978,7 +988,7 @@ func TestUpdateSameMapper(t *testing.T) {
panic(err) panic(err)
} }
cnt, err := testEngine.Id(col1.Id).Incr("`Cnt`").Update(col1) cnt, err := testEngine.ID(col1.Id).Incr("`Cnt`").Update(col1)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -990,7 +1000,7 @@ func TestUpdateSameMapper(t *testing.T) {
} }
newCol := new(UpdateIncr) newCol := new(UpdateIncr)
has, err := testEngine.Id(col1.Id).Get(newCol) has, err := testEngine.ID(col1.Id).Get(newCol)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1093,3 +1103,53 @@ func TestBool(t *testing.T) {
} }
} }
} }
func TestNoUpdate(t *testing.T) {
assert.NoError(t, prepareEngine())
type NoUpdate struct {
Id int64
Content string
}
assertSync(t, new(NoUpdate))
_, err := testEngine.Insert(&NoUpdate{
Content: "test",
})
assert.NoError(t, err)
_, err = testEngine.ID(1).Update(&NoUpdate{})
assert.Error(t, err)
assert.EqualValues(t, "No content found to be updated", err.Error())
}
func TestNewUpdate(t *testing.T) {
assert.NoError(t, prepareEngine())
type TbUserInfo struct {
Id int64 `xorm:"pk autoincr unique BIGINT" json:"id"`
Phone string `xorm:"not null unique VARCHAR(20)" json:"phone"`
UserName string `xorm:"VARCHAR(20)" json:"user_name"`
Gender int `xorm:"default 0 INTEGER" json:"gender"`
Pw string `xorm:"VARCHAR(100)" json:"pw"`
Token string `xorm:"TEXT" json:"token"`
Avatar string `xorm:"TEXT" json:"avatar"`
Extras interface{} `xorm:"JSON" json:"extras"`
Created time.Time `xorm:"DATETIME created"`
Updated time.Time `xorm:"DATETIME updated"`
Deleted time.Time `xorm:"DATETIME deleted"`
}
assertSync(t, new(TbUserInfo))
targetUsr := TbUserInfo{Phone: "13126564922"}
changeUsr := TbUserInfo{Token: "ABCDEFG"}
af, err := testEngine.Update(&changeUsr, &targetUsr)
assert.NoError(t, err)
assert.EqualValues(t, 0, af)
af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", 13126564922).Update(&changeUsr)
assert.NoError(t, err)
assert.EqualValues(t, 0, af)
}

View File

@ -272,6 +272,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
fieldValue := *fieldValuePtr fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface()) fieldType := reflect.TypeOf(fieldValue.Interface())
if fieldType == nil {
continue
}
requiredField := useAllCols requiredField := useAllCols
includeNil := useAllCols includeNil := useAllCols
@ -592,6 +595,22 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
return newColumns return newColumns
} }
func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, 0, len(statement.columnMap))
for col := range statement.columnMap {
fields := strings.Split(strings.TrimSpace(col), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
return newColumns
}
// Distinct generates "DISTINCT col1, col2 " statement // Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement { func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true statement.IsDistinct = true
@ -618,7 +637,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
statement.columnMap[strings.ToLower(nc)] = true statement.columnMap[strings.ToLower(nc)] = true
} }
newColumns := statement.col2NewColsWithQuote(columns...) newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ") statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
return statement return statement
@ -890,17 +909,24 @@ func (statement *Statement) buildConds(table *core.Table, bean interface{}, incl
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
} }
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { func (statement *Statement) mergeConds(bean interface{}) error {
if !statement.noAutoCondition { if !statement.noAutoCondition {
var addedTableName = (len(statement.JoinStr) > 0) var addedTableName = (len(statement.JoinStr) > 0)
autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil { if err != nil {
return "", nil, err return err
} }
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
} }
if err := statement.processIDParam(); err != nil { if err := statement.processIDParam(); err != nil {
return err
}
return nil
}
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
@ -940,14 +966,12 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
columnStr = "*" columnStr = "*"
} }
var condSQL string
var condArgs []interface{}
var err error
if isStruct { if isStruct {
condSQL, condArgs, err = statement.genConds(bean) if err := statement.mergeConds(bean); err != nil {
} else { return "", nil, err
condSQL, condArgs, err = builder.ToSQL(statement.cond) }
} }
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -960,10 +984,16 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) { func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
statement.setRefValue(rValue(bean)) var condSQL string
var condArgs []interface{}
condSQL, condArgs, err := statement.genConds(bean) var err error
if len(beans) > 0 {
statement.setRefValue(rValue(beans[0]))
condSQL, condArgs, err = statement.genConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -86,7 +86,7 @@ func TestExtends(t *testing.T) {
} }
tu3 := &tempUser2{tempUser{0, "extends update"}, ""} tu3 := &tempUser2{tempUser{0, "extends update"}, ""}
_, err = testEngine.Id(tu2.TempUser.Id).Update(tu3) _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -124,7 +124,7 @@ func TestExtends(t *testing.T) {
} }
tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}}
_, err = testEngine.Id(tu9.TempUser2.TempUser.Id).Update(tu10) _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -168,7 +168,7 @@ func TestExtends(t *testing.T) {
} }
tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} tu6 := &tempUser3{&tempUser{0, "extends update"}, ""}
_, err = testEngine.Id(tu5.Temp.Id).Update(tu6) _, err = testEngine.ID(tu5.Temp.Id).Update(tu6)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)

View File

@ -51,7 +51,7 @@ func TestCreatedAndUpdated(t *testing.T) {
} }
u.Name = "xxx" u.Name = "xxx"
cnt, err = testEngine.Id(u.Id).Update(u) cnt, err = testEngine.ID(u.Id).Update(u)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -110,48 +110,62 @@ func TestStrangeName(t *testing.T) {
} }
} }
type CreatedUpdated struct {
Id int64
Name string
Value float64 `xorm:"numeric"`
Created time.Time `xorm:"created"`
Created2 time.Time `xorm:"created"`
Updated time.Time `xorm:"updated"`
}
func TestCreatedUpdated(t *testing.T) { func TestCreatedUpdated(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
err := testEngine.Sync(&CreatedUpdated{}) type CreatedUpdated struct {
if err != nil { Id int64
t.Error(err) Name string
panic(err) Value float64 `xorm:"numeric"`
Created time.Time `xorm:"created"`
Created2 time.Time `xorm:"created"`
Updated time.Time `xorm:"updated"`
} }
err := testEngine.Sync(&CreatedUpdated{})
assert.NoError(t, err)
c := &CreatedUpdated{Name: "test"} c := &CreatedUpdated{Name: "test"}
_, err = testEngine.Insert(c) _, err = testEngine.Insert(c)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
c2 := new(CreatedUpdated) c2 := new(CreatedUpdated)
has, err := testEngine.Id(c.Id).Get(c2) has, err := testEngine.ID(c.Id).Get(c2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if !has { assert.True(t, has)
panic(errors.New("no id"))
}
c2.Value -= 1 c2.Value -= 1
_, err = testEngine.Id(c2.Id).Update(c2) _, err = testEngine.ID(c2.Id).Update(c2)
if err != nil { assert.NoError(t, err)
t.Error(err) }
panic(err)
func TestCreatedUpdatedInt64(t *testing.T) {
assert.NoError(t, prepareEngine())
type CreatedUpdatedInt64 struct {
Id int64
Name string
Value float64 `xorm:"numeric"`
Created int64 `xorm:"created"`
Created2 int64 `xorm:"created"`
Updated int64 `xorm:"updated"`
} }
assertSync(t, &CreatedUpdatedInt64{})
c := &CreatedUpdatedInt64{Name: "test"}
_, err := testEngine.Insert(c)
assert.NoError(t, err)
c2 := new(CreatedUpdatedInt64)
has, err := testEngine.ID(c.Id).Get(c2)
assert.NoError(t, err)
assert.True(t, has)
c2.Value -= 1
_, err = testEngine.ID(c2.Id).Update(c2)
assert.NoError(t, err)
} }
type Lowercase struct { type Lowercase struct {
@ -270,3 +284,77 @@ func TestTagComment(t *testing.T) {
assert.EqualValues(t, 1, len(tables[0].Columns())) assert.EqualValues(t, 1, len(tables[0].Columns()))
assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment)
} }
func TestTagDefault(t *testing.T) {
assert.NoError(t, prepareEngine())
type DefaultStruct struct {
Id int64
Name string
Age int `xorm:"default(10)"`
}
assertSync(t, new(DefaultStruct))
cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{
Name: "test",
Age: 20,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s DefaultStruct
has, err := testEngine.ID(1).Get(&s)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 10, s.Age)
assert.EqualValues(t, "test", s.Name)
}
func TestTagsDirection(t *testing.T) {
assert.NoError(t, prepareEngine())
type OnlyFromDBStruct struct {
Id int64
Name string
Uuid string `xorm:"<- default '1'"`
}
assertSync(t, new(OnlyFromDBStruct))
cnt, err := testEngine.Insert(&OnlyFromDBStruct{
Name: "test",
Uuid: "2",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s OnlyFromDBStruct
has, err := testEngine.ID(1).Get(&s)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "1", s.Uuid)
assert.EqualValues(t, "test", s.Name)
type OnlyToDBStruct struct {
Id int64
Name string
Uuid string `xorm:"->"`
}
assertSync(t, new(OnlyToDBStruct))
cnt, err = testEngine.Insert(&OnlyToDBStruct{
Name: "test",
Uuid: "2",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s2 OnlyToDBStruct
has, err = testEngine.ID(1).Get(&s2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "", s2.Uuid)
assert.EqualValues(t, "test", s2.Name)
}

View File

@ -49,7 +49,7 @@ func TestVersion1(t *testing.T) {
} }
newVer := new(VersionS) newVer := new(VersionS)
has, err := testEngine.Id(ver.Id).Get(newVer) has, err := testEngine.ID(ver.Id).Get(newVer)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -67,7 +67,7 @@ func TestVersion1(t *testing.T) {
} }
newVer.Name = "-------" newVer.Name = "-------"
_, err = testEngine.Id(ver.Id).Update(newVer) _, err = testEngine.ID(ver.Id).Update(newVer)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -78,7 +78,7 @@ func TestVersion1(t *testing.T) {
} }
newVer = new(VersionS) newVer = new(VersionS)
has, err = testEngine.Id(ver.Id).Get(newVer) has, err = testEngine.ID(ver.Id).Get(newVer)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)

View File

@ -176,7 +176,7 @@ func TestNullStructUpdate(t *testing.T) {
item.Age = sql.NullInt64{23, true} item.Age = sql.NullInt64{23, true}
item.Height = sql.NullFloat64{0, false} // update to NULL item.Height = sql.NullFloat64{0, false} // update to NULL
affected, err := testEngine.Id(2).Cols("age", "height", "is_man").Update(item) affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -224,7 +224,7 @@ func TestNullStructUpdate(t *testing.T) {
// IsMan: sql.NullBool{true, true}, // IsMan: sql.NullBool{true, true},
} }
_, err := testEngine.AllCols().Id(6).Update(item) _, err := testEngine.AllCols().ID(6).Update(item)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -268,7 +268,7 @@ func TestNullStructFind(t *testing.T) {
if true { if true {
item := new(NullType) item := new(NullType)
has, err := testEngine.Id(1).Get(item) has, err := testEngine.ID(1).Get(item)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -305,7 +305,7 @@ func TestNullStructFind(t *testing.T) {
if true { if true {
item := make([]NullType, 0) item := make([]NullType, 0)
err := testEngine.Id(2).Find(&item) err := testEngine.ID(2).Find(&item)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -390,7 +390,7 @@ func TestNullStructDelete(t *testing.T) {
item := new(NullType) item := new(NullType)
_, err := testEngine.Id(1).Delete(item) _, err := testEngine.ID(1).Delete(item)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)

View File

@ -37,7 +37,7 @@ func TestArrayField(t *testing.T) {
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var arr ArrayStruct var arr ArrayStruct
has, err := testEngine.Id(1).Get(&arr) has, err := testEngine.ID(1).Get(&arr)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, as.Name, arr.Name) assert.Equal(t, as.Name, arr.Name)
@ -320,7 +320,7 @@ func TestCustomType2(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
user := UserCus{} user := UserCus{}
exist, err := testEngine.Id(1).Get(&user) exist, err := testEngine.ID(1).Get(&user)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exist) assert.True(t, exist)

View File

@ -17,7 +17,7 @@ import (
const ( const (
// Version show the xorm's version // Version show the xorm's version
Version string = "0.6.2.0605" Version string = "0.6.3.0713"
) )
func regDrvsNDialects() bool { func regDrvsNDialects() bool {
@ -50,10 +50,13 @@ func close(engine *Engine) {
engine.Close() engine.Close()
} }
func init() {
regDrvsNDialects()
}
// NewEngine new a db manager according to the parameter. Currently support four // NewEngine new a db manager according to the parameter. Currently support four
// drivers // drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
regDrvsNDialects()
driver := core.QueryDriver(driverName) driver := core.QueryDriver(driverName)
if driver == nil { if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName) return nil, fmt.Errorf("Unsupported driver name: %v", driverName)

View File

@ -12,6 +12,7 @@ import (
"github.com/go-xorm/core" "github.com/go-xorm/core"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
) )
var ( var (

View File

@ -496,7 +496,7 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
func (c *context) File(file string) (err error) { func (c *context) File(file string) (err error) {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {
return ErrNotFound return NotFoundHandler(c)
} }
defer f.Close() defer f.Close()
@ -505,7 +505,7 @@ func (c *context) File(file string) (err error) {
file = filepath.Join(file, indexPage) file = filepath.Join(file, indexPage)
f, err = os.Open(file) f, err = os.Open(file)
if err != nil { if err != nil {
return ErrNotFound return NotFoundHandler(c)
} }
defer f.Close() defer f.Close()
if fi, err = f.Stat(); err != nil { if fi, err = f.Stat(); err != nil {
@ -525,7 +525,7 @@ func (c *context) Inline(file, name string) (err error) {
} }
func (c *context) contentDisposition(file, name, dispositionType string) (err error) { func (c *context) contentDisposition(file, name, dispositionType string) (err error) {
c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%s", dispositionType, name)) c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name))
c.File(file) c.File(file)
return return
} }

View File

@ -187,7 +187,7 @@ func TestContext(t *testing.T) {
err = c.Attachment("_fixture/images/walle.png", "walle.png") err = c.Attachment("_fixture/images/walle.png", "walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "attachment; filename=walle.png", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
@ -197,7 +197,7 @@ func TestContext(t *testing.T) {
err = c.Inline("_fixture/images/walle.png", "walle.png") err = c.Inline("_fixture/images/walle.png", "walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "inline; filename=walle.png", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }

View File

@ -282,7 +282,7 @@ func New() (e *Echo) {
e.TLSServer.Handler = e e.TLSServer.Handler = e
e.HTTPErrorHandler = e.DefaultHTTPErrorHandler e.HTTPErrorHandler = e.DefaultHTTPErrorHandler
e.Binder = &DefaultBinder{} e.Binder = &DefaultBinder{}
e.Logger.SetLevel(log.OFF) e.Logger.SetLevel(log.ERROR)
e.stdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) e.stdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0)
e.pool.New = func() interface{} { e.pool.New = func() interface{} {
return e.NewContext(nil, nil) return e.NewContext(nil, nil)
@ -295,7 +295,7 @@ func New() (e *Echo) {
func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context {
return &context{ return &context{
request: r, request: r,
response: &Response{echo: e, Writer: w}, response: NewResponse(w, e),
store: make(Map), store: make(Map),
echo: e, echo: e,
pvalues: make([]string, *e.maxParam), pvalues: make([]string, *e.maxParam),

View File

@ -21,66 +21,66 @@ func (g *Group) Use(middleware ...MiddlewareFunc) {
// Allow all requests to reach the group as they might get dropped if router // Allow all requests to reach the group as they might get dropped if router
// doesn't find a match, making none of the group middleware process. // doesn't find a match, making none of the group middleware process.
g.echo.Any(path.Clean(g.prefix+"/*"), func(c Context) error { g.echo.Any(path.Clean(g.prefix+"/*"), func(c Context) error {
return ErrNotFound return NotFoundHandler(c)
}, g.middleware...) }, g.middleware...)
} }
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(CONNECT, path, h, m...) return g.Add(CONNECT, path, h, m...)
} }
// DELETE implements `Echo#DELETE()` for sub-routes within the Group. // DELETE implements `Echo#DELETE()` for sub-routes within the Group.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(DELETE, path, h, m...) return g.Add(DELETE, path, h, m...)
} }
// GET implements `Echo#GET()` for sub-routes within the Group. // GET implements `Echo#GET()` for sub-routes within the Group.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(GET, path, h, m...) return g.Add(GET, path, h, m...)
} }
// HEAD implements `Echo#HEAD()` for sub-routes within the Group. // HEAD implements `Echo#HEAD()` for sub-routes within the Group.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(HEAD, path, h, m...) return g.Add(HEAD, path, h, m...)
} }
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. // OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(OPTIONS, path, h, m...) return g.Add(OPTIONS, path, h, m...)
} }
// PATCH implements `Echo#PATCH()` for sub-routes within the Group. // PATCH implements `Echo#PATCH()` for sub-routes within the Group.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(PATCH, path, h, m...) return g.Add(PATCH, path, h, m...)
} }
// POST implements `Echo#POST()` for sub-routes within the Group. // POST implements `Echo#POST()` for sub-routes within the Group.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(POST, path, h, m...) return g.Add(POST, path, h, m...)
} }
// PUT implements `Echo#PUT()` for sub-routes within the Group. // PUT implements `Echo#PUT()` for sub-routes within the Group.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(PUT, path, h, m...) return g.Add(PUT, path, h, m...)
} }
// TRACE implements `Echo#TRACE()` for sub-routes within the Group. // TRACE implements `Echo#TRACE()` for sub-routes within the Group.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.add(TRACE, path, h, m...) return g.Add(TRACE, path, h, m...)
} }
// Any implements `Echo#Any()` for sub-routes within the Group. // Any implements `Echo#Any()` for sub-routes within the Group.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) { func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
for _, m := range methods { for _, m := range methods {
g.add(m, path, handler, middleware...) g.Add(m, path, handler, middleware...)
} }
} }
// Match implements `Echo#Match()` for sub-routes within the Group. // Match implements `Echo#Match()` for sub-routes within the Group.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) { func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
for _, m := range methods { for _, m := range methods {
g.add(m, path, handler, middleware...) g.Add(m, path, handler, middleware...)
} }
} }
@ -102,7 +102,8 @@ func (g *Group) File(path, file string) {
g.echo.File(g.prefix+path, file) g.echo.File(g.prefix+path, file)
} }
func (g *Group) add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { // Add implements `Echo#Add()` for sub-routes within the Group.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
// Combine into a new slice to avoid accidentally passing the same slice for // Combine into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the // multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls. // middleware from earlier calls.

View File

@ -141,7 +141,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if _, ok := config.Claims.(jwt.MapClaims); ok { if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.keyFunc) token, err = jwt.Parse(auth, config.keyFunc)
} else { } else {
claims := reflect.ValueOf(config.Claims).Interface().(jwt.Claims) t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc) token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
} }
if err == nil && token.Valid { if err == nil && token.Valid {

View File

@ -22,6 +22,42 @@ type jwtCustomClaims struct {
jwtCustomInfo jwtCustomInfo
} }
func TestJWTRace(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
validKey := []byte("secret")
h := JWTWithConfig(JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: validKey,
})(handler)
makeReq := func(token string) echo.Context {
req := httptest.NewRequest(echo.GET, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
c := e.NewContext(req, res)
assert.NoError(t, h(c))
return c
}
c := makeReq(initialToken)
user := c.Get("user").(*jwt.Token)
claims := user.Claims.(*jwtCustomClaims)
assert.Equal(t, claims.Name, "John Doe")
makeReq(raceToken)
user = c.Get("user").(*jwt.Token)
claims = user.Claims.(*jwtCustomClaims)
// Initial context should still be "John Doe", not "Race Condition"
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
}
func TestJWT(t *testing.T) { func TestJWT(t *testing.T) {
e := echo.New() e := echo.New()
handler := func(c echo.Context) error { handler := func(c echo.Context) error {

View File

@ -142,6 +142,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
tgt := config.Balancer.Next() tgt := config.Balancer.Next()

View File

@ -20,6 +20,11 @@ type (
} }
) )
// NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
return &Response{Writer: w, echo: e}
}
// Header returns the header map for the writer that will be sent by // Header returns the header map for the writer that will be sent by
// WriteHeader. Changing the header after a call to WriteHeader (or Write) has // WriteHeader. Changing the header after a call to WriteHeader (or Write) has
// no effect unless the modified headers were declared as trailers by setting // no effect unless the modified headers were declared as trailers by setting

View File

@ -1,4 +1,4 @@
// +build go1.7,!go1.8 // +build go1.7, !go1.8
package echo package echo

6
vendor/github.com/lib/pq/array.go generated vendored
View File

@ -13,7 +13,7 @@ import (
var typeByteSlice = reflect.TypeOf([]byte{}) var typeByteSlice = reflect.TypeOf([]byte{})
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
// Array returns the optimal driver.Valuer and sql.Scanner for an array or // Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension. // slice of any dimension.
@ -278,7 +278,7 @@ func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]b
// TODO calculate the assign function for other types // TODO calculate the assign function for other types
// TODO repeat this section on the element type of arrays or slices (multidimensional) // TODO repeat this section on the element type of arrays or slices (multidimensional)
{ {
if reflect.PtrTo(rt).Implements(typeSqlScanner) { if reflect.PtrTo(rt).Implements(typeSQLScanner) {
// dest is always addressable because it is an element of a slice. // dest is always addressable because it is an element of a slice.
assign = func(src []byte, dest reflect.Value) (err error) { assign = func(src []byte, dest reflect.Value) (err error) {
ss := dest.Addr().Interface().(sql.Scanner) ss := dest.Addr().Interface().(sql.Scanner)
@ -587,7 +587,7 @@ func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
} }
} }
var del string = "," var del = ","
var err error var err error
var iv interface{} = rv.Interface() var iv interface{} = rv.Interface()

83
vendor/github.com/lib/pq/conn.go generated vendored
View File

@ -27,12 +27,12 @@ var (
ErrNotSupported = errors.New("pq: Unsupported command") ErrNotSupported = errors.New("pq: Unsupported command")
ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
errUnexpectedReady = errors.New("unexpected ReadyForQuery") errUnexpectedReady = errors.New("unexpected ReadyForQuery")
errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
errNoLastInsertId = errors.New("no LastInsertId available after the empty statement") errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
) )
type Driver struct{} type Driver struct{}
@ -131,7 +131,7 @@ type conn struct {
} }
// Handle driver-side settings in parsed connection string. // Handle driver-side settings in parsed connection string.
func (c *conn) handleDriverSettings(o values) (err error) { func (cn *conn) handleDriverSettings(o values) (err error) {
boolSetting := func(key string, val *bool) error { boolSetting := func(key string, val *bool) error {
if value, ok := o[key]; ok { if value, ok := o[key]; ok {
if value == "yes" { if value == "yes" {
@ -145,18 +145,18 @@ func (c *conn) handleDriverSettings(o values) (err error) {
return nil return nil
} }
err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
if err != nil { if err != nil {
return err return err
} }
err = boolSetting("binary_parameters", &c.binaryParameters) err = boolSetting("binary_parameters", &cn.binaryParameters)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func (c *conn) handlePgpass(o values) { func (cn *conn) handlePgpass(o values) {
// if a password was supplied, do not process .pgpass // if a password was supplied, do not process .pgpass
if _, ok := o["password"]; ok { if _, ok := o["password"]; ok {
return return
@ -229,10 +229,10 @@ func (c *conn) handlePgpass(o values) {
} }
} }
func (c *conn) writeBuf(b byte) *writeBuf { func (cn *conn) writeBuf(b byte) *writeBuf {
c.scratch[0] = b cn.scratch[0] = b
return &writeBuf{ return &writeBuf{
buf: c.scratch[:5], buf: cn.scratch[:5],
pos: 1, pos: 1,
} }
} }
@ -310,9 +310,8 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
u, err := userCurrent() u, err := userCurrent()
if err != nil { if err != nil {
return nil, err return nil, err
} else {
o["user"] = u
} }
o["user"] = u
} }
cn := &conn{ cn := &conn{
@ -698,7 +697,7 @@ var emptyRows noRows
var _ driver.Result = noRows{} var _ driver.Result = noRows{}
func (noRows) LastInsertId() (int64, error) { func (noRows) LastInsertId() (int64, error) {
return 0, errNoLastInsertId return 0, errNoLastInsertID
} }
func (noRows) RowsAffected() (int64, error) { func (noRows) RowsAffected() (int64, error) {
@ -840,16 +839,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
cn.postExecuteWorkaround() cn.postExecuteWorkaround()
return rows, nil return rows, nil
} else {
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
}, nil
} }
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
}, nil
} }
// Implement the optional "Execer" interface for one-shot queries // Implement the optional "Execer" interface for one-shot queries
@ -876,17 +874,16 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
cn.postExecuteWorkaround() cn.postExecuteWorkaround()
res, _, err = cn.readExecuteResponse("Execute") res, _, err = cn.readExecuteResponse("Execute")
return res, err return res, err
} else {
// Use the unnamed statement to defer planning until bind
// time, or else value-based selectivity estimates cannot be
// used.
st := cn.prepareTo(query, "")
r, err := st.Exec(args)
if err != nil {
panic(err)
}
return r, err
} }
// Use the unnamed statement to defer planning until bind
// time, or else value-based selectivity estimates cannot be
// used.
st := cn.prepareTo(query, "")
r, err := st.Exec(args)
if err != nil {
panic(err)
}
return r, err
} }
func (cn *conn) send(m *writeBuf) { func (cn *conn) send(m *writeBuf) {
@ -1147,10 +1144,10 @@ const formatText format = 0
const formatBinary format = 1 const formatBinary format = 1
// One result-column format code with the value 1 (i.e. all binary). // One result-column format code with the value 1 (i.e. all binary).
var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} var colFmtDataAllBinary = []byte{0, 1, 0, 1}
// No result-column format codes (i.e. all text). // No result-column format codes (i.e. all text).
var colFmtDataAllText []byte = []byte{0, 0} var colFmtDataAllText = []byte{0, 0}
type stmt struct { type stmt struct {
cn *conn cn *conn
@ -1515,7 +1512,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
cn.send(b) cn.send(b)
} }
func (c *conn) processParameterStatus(r *readBuf) { func (cn *conn) processParameterStatus(r *readBuf) {
var err error var err error
param := r.string() param := r.string()
@ -1526,13 +1523,13 @@ func (c *conn) processParameterStatus(r *readBuf) {
var minor int var minor int
_, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
if err == nil { if err == nil {
c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
} }
case "TimeZone": case "TimeZone":
c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
if err != nil { if err != nil {
c.parameterStatus.currentLocation = nil cn.parameterStatus.currentLocation = nil
} }
default: default:
@ -1540,8 +1537,8 @@ func (c *conn) processParameterStatus(r *readBuf) {
} }
} }
func (c *conn) processReadyForQuery(r *readBuf) { func (cn *conn) processReadyForQuery(r *readBuf) {
c.txnStatus = transactionStatus(r.byte()) cn.txnStatus = transactionStatus(r.byte())
} }
func (cn *conn) readReadyForQuery() { func (cn *conn) readReadyForQuery() {
@ -1556,9 +1553,9 @@ func (cn *conn) readReadyForQuery() {
} }
} }
func (c *conn) processBackendKeyData(r *readBuf) { func (cn *conn) processBackendKeyData(r *readBuf) {
c.processID = r.int32() cn.processID = r.int32()
c.secretKey = r.int32() cn.secretKey = r.int32()
} }
func (cn *conn) readParseResponse() { func (cn *conn) readParseResponse() {

View File

@ -136,7 +136,7 @@ func TestOpenURL(t *testing.T) {
testURL("postgresql://") testURL("postgresql://")
} }
const pgpass_file = "/tmp/pqgotest_pgpass" const pgpassFile = "/tmp/pqgotest_pgpass"
func TestPgpass(t *testing.T) { func TestPgpass(t *testing.T) {
if os.Getenv("TRAVIS") != "true" { if os.Getenv("TRAVIS") != "true" {
@ -172,10 +172,10 @@ func TestPgpass(t *testing.T) {
txn.Rollback() txn.Rollback()
} }
testAssert("", "ok", "missing .pgpass, unexpected error %#v") testAssert("", "ok", "missing .pgpass, unexpected error %#v")
os.Setenv("PGPASSFILE", pgpass_file) os.Setenv("PGPASSFILE", pgpassFile)
testAssert("host=/tmp", "fail", ", unexpected error %#v") testAssert("host=/tmp", "fail", ", unexpected error %#v")
os.Remove(pgpass_file) os.Remove(pgpassFile)
pgpass, err := os.OpenFile(pgpass_file, os.O_RDWR|os.O_CREATE, 0644) pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644)
if err != nil { if err != nil {
t.Fatalf("Unexpected error writing pgpass file %#v", err) t.Fatalf("Unexpected error writing pgpass file %#v", err)
} }
@ -213,7 +213,7 @@ localhost:*:*:*:pass_C
// wrong permissions for the pgpass file means it should be ignored // wrong permissions for the pgpass file means it should be ignored
assertPassword(values{"host": "example.com", "user": "foo"}, "") assertPassword(values{"host": "example.com", "user": "foo"}, "")
// fix the permissions and check if it has taken effect // fix the permissions and check if it has taken effect
os.Chmod(pgpass_file, 0600) os.Chmod(pgpassFile, 0600)
assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A") assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A")
assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback") assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback")
assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B") assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B")
@ -221,7 +221,7 @@ localhost:*:*:*:pass_C
assertPassword(values{"host": "", "user": "some_user"}, "pass_C") assertPassword(values{"host": "", "user": "some_user"}, "pass_C")
assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C") assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C")
// cleanup // cleanup
os.Remove(pgpass_file) os.Remove(pgpassFile)
os.Setenv("PGPASSFILE", "") os.Setenv("PGPASSFILE", "")
} }
@ -393,8 +393,8 @@ func TestEmptyQuery(t *testing.T) {
if _, err := res.RowsAffected(); err != errNoRowsAffected { if _, err := res.RowsAffected(); err != errNoRowsAffected {
t.Fatalf("expected %s, got %v", errNoRowsAffected, err) t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
} }
if _, err := res.LastInsertId(); err != errNoLastInsertId { if _, err := res.LastInsertId(); err != errNoLastInsertID {
t.Fatalf("expected %s, got %v", errNoLastInsertId, err) t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
} }
rows, err := db.Query("") rows, err := db.Query("")
if err != nil { if err != nil {
@ -425,8 +425,8 @@ func TestEmptyQuery(t *testing.T) {
if _, err := res.RowsAffected(); err != errNoRowsAffected { if _, err := res.RowsAffected(); err != errNoRowsAffected {
t.Fatalf("expected %s, got %v", errNoRowsAffected, err) t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
} }
if _, err := res.LastInsertId(); err != errNoLastInsertId { if _, err := res.LastInsertId(); err != errNoLastInsertID {
t.Fatalf("expected %s, got %v", errNoLastInsertId, err) t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
} }
rows, err = stmt.Query() rows, err = stmt.Query()
if err != nil { if err != nil {
@ -1053,16 +1053,16 @@ func TestIssue282(t *testing.T) {
db := openTestConn(t) db := openTestConn(t)
defer db.Close() defer db.Close()
var search_path string var searchPath string
err := db.QueryRow(` err := db.QueryRow(`
SET LOCAL search_path TO pg_catalog; SET LOCAL search_path TO pg_catalog;
SET LOCAL search_path TO pg_catalog; SET LOCAL search_path TO pg_catalog;
SHOW search_path`).Scan(&search_path) SHOW search_path`).Scan(&searchPath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if search_path != "pg_catalog" { if searchPath != "pg_catalog" {
t.Fatalf("unexpected search_path %s", search_path) t.Fatalf("unexpected search_path %s", searchPath)
} }
} }

View File

@ -370,17 +370,17 @@ func TestInfinityTimestamp(t *testing.T) {
t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String()) t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String())
} }
y_1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC) ym1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC)
y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC) y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC)
var s string var s string
err = db.QueryRow("SELECT $1::timestamp::text", y_1500).Scan(&s) err = db.QueryRow("SELECT $1::timestamp::text", ym1500).Scan(&s)
if err != nil { if err != nil {
t.Errorf("Encoding -infinity, expected no error, got %q", err) t.Errorf("Encoding -infinity, expected no error, got %q", err)
} }
if s != "-infinity" { if s != "-infinity" {
t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s) t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s)
} }
err = db.QueryRow("SELECT $1::timestamptz::text", y_1500).Scan(&s) err = db.QueryRow("SELECT $1::timestamptz::text", ym1500).Scan(&s)
if err != nil { if err != nil {
t.Errorf("Encoding -infinity, expected no error, got %q", err) t.Errorf("Encoding -infinity, expected no error, got %q", err)
} }

View File

@ -33,7 +33,7 @@ func TestDecodeUUIDBackend(t *testing.T) {
db := openTestConn(t) db := openTestConn(t)
defer db.Close() defer db.Close()
var s string = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a" var s = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a"
var scanned interface{} var scanned interface{}
err := db.QueryRow(`SELECT $1::uuid`, s).Scan(&scanned) err := db.QueryRow(`SELECT $1::uuid`, s).Scan(&scanned)

View File

@ -29,6 +29,15 @@ const (
backgroundMask = (backgroundRed | backgroundBlue | backgroundGreen | backgroundIntensity) backgroundMask = (backgroundRed | backgroundBlue | backgroundGreen | backgroundIntensity)
) )
const (
genericRead = 0x80000000
genericWrite = 0x40000000
)
const (
consoleTextmodeBuffer = 0x1
)
type wchar uint16 type wchar uint16
type short int16 type short int16
type dword uint32 type dword uint32
@ -69,14 +78,17 @@ var (
procGetConsoleCursorInfo = kernel32.NewProc("GetConsoleCursorInfo") procGetConsoleCursorInfo = kernel32.NewProc("GetConsoleCursorInfo")
procSetConsoleCursorInfo = kernel32.NewProc("SetConsoleCursorInfo") procSetConsoleCursorInfo = kernel32.NewProc("SetConsoleCursorInfo")
procSetConsoleTitle = kernel32.NewProc("SetConsoleTitleW") procSetConsoleTitle = kernel32.NewProc("SetConsoleTitleW")
procCreateConsoleScreenBuffer = kernel32.NewProc("CreateConsoleScreenBuffer")
) )
// Writer provide colorable Writer to the console // Writer provide colorable Writer to the console
type Writer struct { type Writer struct {
out io.Writer out io.Writer
handle syscall.Handle handle syscall.Handle
oldattr word althandle syscall.Handle
oldpos coord oldattr word
oldpos coord
rest bytes.Buffer
} }
// NewColorable return new instance of Writer which handle escape sequence from File. // NewColorable return new instance of Writer which handle escape sequence from File.
@ -407,7 +419,18 @@ func (w *Writer) Write(data []byte) (n int, err error) {
var csbi consoleScreenBufferInfo var csbi consoleScreenBufferInfo
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
er := bytes.NewReader(data) handle := w.handle
var er *bytes.Reader
if w.rest.Len() > 0 {
var rest bytes.Buffer
w.rest.WriteTo(&rest)
w.rest.Reset()
rest.Write(data)
er = bytes.NewReader(rest.Bytes())
} else {
er = bytes.NewReader(data)
}
var bw [1]byte var bw [1]byte
loop: loop:
for { for {
@ -426,28 +449,42 @@ loop:
} }
if c2 == ']' { if c2 == ']' {
if err := doTitleSequence(er); err != nil { w.rest.WriteByte(c1)
w.rest.WriteByte(c2)
er.WriteTo(&w.rest)
if bytes.IndexByte(w.rest.Bytes(), 0x07) == -1 {
break loop break loop
} }
er = bytes.NewReader(w.rest.Bytes()[2:])
err := doTitleSequence(er)
if err != nil {
break loop
}
w.rest.Reset()
continue continue
} }
if c2 != 0x5b { if c2 != 0x5b {
continue continue
} }
w.rest.WriteByte(c1)
w.rest.WriteByte(c2)
er.WriteTo(&w.rest)
var buf bytes.Buffer var buf bytes.Buffer
var m byte var m byte
for { for i, c := range w.rest.Bytes()[2:] {
c, err := er.ReadByte()
if err != nil {
break loop
}
if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' { if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' {
m = c m = c
er = bytes.NewReader(w.rest.Bytes()[2+i+1:])
w.rest.Reset()
break break
} }
buf.Write([]byte(string(c))) buf.Write([]byte(string(c)))
} }
if m == 0 {
break loop
}
switch m { switch m {
case 'A': case 'A':
@ -455,61 +492,64 @@ loop:
if err != nil { if err != nil {
continue continue
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.y -= short(n) csbi.cursorPosition.y -= short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'B': case 'B':
n, err = strconv.Atoi(buf.String()) n, err = strconv.Atoi(buf.String())
if err != nil { if err != nil {
continue continue
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.y += short(n) csbi.cursorPosition.y += short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'C': case 'C':
n, err = strconv.Atoi(buf.String()) n, err = strconv.Atoi(buf.String())
if err != nil { if err != nil {
continue continue
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x += short(n) csbi.cursorPosition.x += short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'D': case 'D':
n, err = strconv.Atoi(buf.String()) n, err = strconv.Atoi(buf.String())
if err != nil { if err != nil {
continue continue
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x -= short(n) csbi.cursorPosition.x -= short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) if csbi.cursorPosition.x < 0 {
csbi.cursorPosition.x = 0
}
procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'E': case 'E':
n, err = strconv.Atoi(buf.String()) n, err = strconv.Atoi(buf.String())
if err != nil { if err != nil {
continue continue
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x = 0 csbi.cursorPosition.x = 0
csbi.cursorPosition.y += short(n) csbi.cursorPosition.y += short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'F': case 'F':
n, err = strconv.Atoi(buf.String()) n, err = strconv.Atoi(buf.String())
if err != nil { if err != nil {
continue continue
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x = 0 csbi.cursorPosition.x = 0
csbi.cursorPosition.y -= short(n) csbi.cursorPosition.y -= short(n)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'G': case 'G':
n, err = strconv.Atoi(buf.String()) n, err = strconv.Atoi(buf.String())
if err != nil { if err != nil {
continue continue
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
csbi.cursorPosition.x = short(n - 1) csbi.cursorPosition.x = short(n - 1)
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'H': case 'H', 'f':
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
if buf.Len() > 0 { if buf.Len() > 0 {
token := strings.Split(buf.String(), ";") token := strings.Split(buf.String(), ";")
switch len(token) { switch len(token) {
@ -534,7 +574,7 @@ loop:
} else { } else {
csbi.cursorPosition.y = 0 csbi.cursorPosition.y = 0
} }
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition)))
case 'J': case 'J':
n := 0 n := 0
if buf.Len() > 0 { if buf.Len() > 0 {
@ -545,7 +585,7 @@ loop:
} }
var count, written dword var count, written dword
var cursor coord var cursor coord
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
switch n { switch n {
case 0: case 0:
cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y} cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y}
@ -557,8 +597,8 @@ loop:
cursor = coord{x: csbi.window.left, y: csbi.window.top} cursor = coord{x: csbi.window.left, y: csbi.window.top}
count = dword(csbi.size.x - csbi.cursorPosition.x + (csbi.size.y-csbi.cursorPosition.y)*csbi.size.x) count = dword(csbi.size.x - csbi.cursorPosition.x + (csbi.size.y-csbi.cursorPosition.y)*csbi.size.x)
} }
procFillConsoleOutputCharacter.Call(uintptr(w.handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written))) procFillConsoleOutputCharacter.Call(uintptr(handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputAttribute.Call(uintptr(w.handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written))) procFillConsoleOutputAttribute.Call(uintptr(handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
case 'K': case 'K':
n := 0 n := 0
if buf.Len() > 0 { if buf.Len() > 0 {
@ -567,26 +607,28 @@ loop:
continue continue
} }
} }
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
var cursor coord var cursor coord
var count, written dword
switch n { switch n {
case 0: case 0:
cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y} cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y}
count = dword(csbi.size.x - csbi.cursorPosition.x)
case 1: case 1:
cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y} cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y}
count = dword(csbi.size.x - csbi.cursorPosition.x)
case 2: case 2:
cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y} cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y}
count = dword(csbi.size.x)
} }
var count, written dword procFillConsoleOutputCharacter.Call(uintptr(handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
count = dword(csbi.size.x - csbi.cursorPosition.x) procFillConsoleOutputAttribute.Call(uintptr(handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputCharacter.Call(uintptr(w.handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
procFillConsoleOutputAttribute.Call(uintptr(w.handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written)))
case 'm': case 'm':
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
attr := csbi.attributes attr := csbi.attributes
cs := buf.String() cs := buf.String()
if cs == "" { if cs == "" {
procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(w.oldattr)) procSetConsoleTextAttribute.Call(uintptr(handle), uintptr(w.oldattr))
continue continue
} }
token := strings.Split(cs, ";") token := strings.Split(cs, ";")
@ -625,6 +667,21 @@ loop:
attr |= n256foreAttr[n256] attr |= n256foreAttr[n256]
i += 2 i += 2
} }
} else if len(token) == 5 && token[i+1] == "2" {
var r, g, b int
r, _ = strconv.Atoi(token[i+2])
g, _ = strconv.Atoi(token[i+3])
b, _ = strconv.Atoi(token[i+4])
i += 4
if r > 127 {
attr |= foregroundRed
}
if g > 127 {
attr |= foregroundGreen
}
if b > 127 {
attr |= foregroundBlue
}
} else { } else {
attr = attr & (w.oldattr & backgroundMask) attr = attr & (w.oldattr & backgroundMask)
} }
@ -652,6 +709,21 @@ loop:
attr |= n256backAttr[n256] attr |= n256backAttr[n256]
i += 2 i += 2
} }
} else if len(token) == 5 && token[i+1] == "2" {
var r, g, b int
r, _ = strconv.Atoi(token[i+2])
g, _ = strconv.Atoi(token[i+3])
b, _ = strconv.Atoi(token[i+4])
i += 4
if r > 127 {
attr |= backgroundRed
}
if g > 127 {
attr |= backgroundGreen
}
if b > 127 {
attr |= backgroundBlue
}
} else { } else {
attr = attr & (w.oldattr & foregroundMask) attr = attr & (w.oldattr & foregroundMask)
} }
@ -683,30 +755,52 @@ loop:
attr |= backgroundBlue attr |= backgroundBlue
} }
} }
procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(attr)) procSetConsoleTextAttribute.Call(uintptr(handle), uintptr(attr))
} }
} }
case 'h': case 'h':
var ci consoleCursorInfo
cs := buf.String() cs := buf.String()
if cs == "?25" { if cs == "5>" {
var ci consoleCursorInfo procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
procGetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci))) ci.visible = 0
procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?25" {
procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
ci.visible = 1 ci.visible = 1
procSetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci))) procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?1049" {
if w.althandle == 0 {
h, _, _ := procCreateConsoleScreenBuffer.Call(uintptr(genericRead|genericWrite), 0, 0, uintptr(consoleTextmodeBuffer), 0, 0)
w.althandle = syscall.Handle(h)
if w.althandle != 0 {
handle = w.althandle
}
}
} }
case 'l': case 'l':
var ci consoleCursorInfo
cs := buf.String() cs := buf.String()
if cs == "?25" { if cs == "5>" {
var ci consoleCursorInfo procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
procGetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci))) ci.visible = 1
procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?25" {
procGetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
ci.visible = 0 ci.visible = 0
procSetConsoleCursorInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&ci))) procSetConsoleCursorInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&ci)))
} else if cs == "?1049" {
if w.althandle != 0 {
syscall.CloseHandle(w.althandle)
w.althandle = 0
handle = w.handle
}
} }
case 's': case 's':
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
w.oldpos = csbi.cursorPosition w.oldpos = csbi.cursorPosition
case 'u': case 'u':
procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&w.oldpos))) procSetConsoleCursorPosition.Call(uintptr(handle), *(*uintptr)(unsafe.Pointer(&w.oldpos)))
} }
} }

View File

@ -65,7 +65,7 @@ FAQ
* Want to get time.Time with current locale * Want to get time.Time with current locale
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`. Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
* Can I use this in multiple routines concurrently? * Can I use this in multiple routines concurrently?

View File

@ -167,7 +167,7 @@ type SQLiteDriver struct {
// SQLiteConn implement sql.Conn. // SQLiteConn implement sql.Conn.
type SQLiteConn struct { type SQLiteConn struct {
dbMu sync.Mutex mu sync.Mutex
db *C.sqlite3 db *C.sqlite3
loc *time.Location loc *time.Location
txlock string txlock string
@ -182,6 +182,7 @@ type SQLiteTx struct {
// SQLiteStmt implement sql.Stmt. // SQLiteStmt implement sql.Stmt.
type SQLiteStmt struct { type SQLiteStmt struct {
mu sync.Mutex
c *SQLiteConn c *SQLiteConn
s *C.sqlite3_stmt s *C.sqlite3_stmt
t string t string
@ -202,6 +203,7 @@ type SQLiteRows struct {
cols []string cols []string
decltype []string decltype []string
cls bool cls bool
closed bool
done chan struct{} done chan struct{}
} }
@ -761,9 +763,9 @@ func (c *SQLiteConn) Close() error {
return c.lastError() return c.lastError()
} }
deleteHandles(c) deleteHandles(c)
c.dbMu.Lock() c.mu.Lock()
c.db = nil c.db = nil
c.dbMu.Unlock() c.mu.Unlock()
runtime.SetFinalizer(c, nil) runtime.SetFinalizer(c, nil)
return nil return nil
} }
@ -772,8 +774,8 @@ func (c *SQLiteConn) dbConnOpen() bool {
if c == nil { if c == nil {
return false return false
} }
c.dbMu.Lock() c.mu.Lock()
defer c.dbMu.Unlock() defer c.mu.Unlock()
return c.db != nil return c.db != nil
} }
@ -802,6 +804,8 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
// Close the statement. // Close the statement.
func (s *SQLiteStmt) Close() error { func (s *SQLiteStmt) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed { if s.closed {
return nil return nil
} }
@ -810,6 +814,7 @@ func (s *SQLiteStmt) Close() error {
return errors.New("sqlite statement with already closed database connection") return errors.New("sqlite statement with already closed database connection")
} }
rv := C.sqlite3_finalize(s.s) rv := C.sqlite3_finalize(s.s)
s.s = nil
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return s.c.lastError() return s.c.lastError()
} }
@ -866,10 +871,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
case float64: case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v)) rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte: case []byte:
if len(v) == 0 { ln := len(v)
if ln == 0 {
v = placeHolder v = placeHolder
} }
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v))) rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
case time.Time: case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0])) b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@ -904,6 +910,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
cols: nil, cols: nil,
decltype: nil, decltype: nil,
cls: s.cls, cls: s.cls,
closed: false,
done: make(chan struct{}), done: make(chan struct{}),
} }
@ -976,25 +983,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
// Close the rows. // Close the rows.
func (rc *SQLiteRows) Close() error { func (rc *SQLiteRows) Close() error {
if rc.s.closed { rc.s.mu.Lock()
if rc.s.closed || rc.closed {
rc.s.mu.Unlock()
return nil return nil
} }
rc.closed = true
if rc.done != nil { if rc.done != nil {
close(rc.done) close(rc.done)
} }
if rc.cls { if rc.cls {
rc.s.mu.Unlock()
return rc.s.Close() return rc.s.Close()
} }
rv := C.sqlite3_reset(rc.s.s) rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
rc.s.mu.Unlock()
return rc.s.c.lastError() return rc.s.c.lastError()
} }
rc.s.mu.Unlock()
return nil return nil
} }
// Columns return column names. // Columns return column names.
func (rc *SQLiteRows) Columns() []string { func (rc *SQLiteRows) Columns() []string {
if rc.nc != len(rc.cols) { rc.s.mu.Lock()
defer rc.s.mu.Unlock()
if rc.s.s != nil && rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc) rc.cols = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ { for i := 0; i < rc.nc; i++ {
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i))) rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
@ -1003,9 +1018,8 @@ func (rc *SQLiteRows) Columns() []string {
return rc.cols return rc.cols
} }
// DeclTypes return column types. func (rc *SQLiteRows) declTypes() []string {
func (rc *SQLiteRows) DeclTypes() []string { if rc.s.s != nil && rc.decltype == nil {
if rc.decltype == nil {
rc.decltype = make([]string, rc.nc) rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ { for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
@ -1014,8 +1028,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
return rc.decltype return rc.decltype
} }
// DeclTypes return column types.
func (rc *SQLiteRows) DeclTypes() []string {
rc.s.mu.Lock()
defer rc.s.mu.Unlock()
return rc.declTypes()
}
// Next move cursor to next. // Next move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error { func (rc *SQLiteRows) Next(dest []driver.Value) error {
if rc.s.closed {
return io.EOF
}
rc.s.mu.Lock()
defer rc.s.mu.Unlock()
rv := C.sqlite3_step(rc.s.s) rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE { if rv == C.SQLITE_DONE {
return io.EOF return io.EOF
@ -1028,7 +1054,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return nil return nil
} }
rc.DeclTypes() rc.declTypes()
for i := range dest { for i := range dest {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) { switch C.sqlite3_column_type(rc.s.s, C.int(i)) {

View File

@ -8,9 +8,13 @@
package sqlite3 package sqlite3
import ( import (
"context"
"database/sql" "database/sql"
"fmt"
"math/rand"
"os" "os"
"testing" "testing"
"time"
) )
func TestNamedParams(t *testing.T) { func TestNamedParams(t *testing.T) {
@ -48,3 +52,91 @@ func TestNamedParams(t *testing.T) {
t.Error("Failed to db.QueryRow: not matched results") t.Error("Failed to db.QueryRow: not matched results")
} }
} }
var (
testTableStatements = []string{
`DROP TABLE IF EXISTS test_table`,
`
CREATE TABLE IF NOT EXISTS test_table (
key1 VARCHAR(64) PRIMARY KEY,
key_id VARCHAR(64) NOT NULL,
key2 VARCHAR(64) NOT NULL,
key3 VARCHAR(64) NOT NULL,
key4 VARCHAR(64) NOT NULL,
key5 VARCHAR(64) NOT NULL,
key6 VARCHAR(64) NOT NULL,
data BLOB NOT NULL
);`,
}
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
func randStringBytes(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}
func initDatabase(t *testing.T, db *sql.DB, rowCount int64) {
t.Logf("Executing db initializing statements")
for _, query := range testTableStatements {
_, err := db.Exec(query)
if err != nil {
t.Fatal(err)
}
}
for i := int64(0); i < rowCount; i++ {
query := `INSERT INTO test_table
(key1, key_id, key2, key3, key4, key5, key6, data)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?);`
args := []interface{}{
randStringBytes(50),
fmt.Sprint(i),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(50),
randStringBytes(2048),
}
_, err := db.Exec(query, args...)
if err != nil {
t.Fatal(err)
}
}
}
func TestShortTimeout(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared")
if err != nil {
t.Fatal(err)
}
defer db.Close()
initDatabase(t, db, 10000)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond)
defer cancel()
query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
FROM test_table
ORDER BY key2 ASC`
rows, err := db.QueryContext(ctx, query)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var key1, keyid, key2, key3, key4, key5, key6 string
var data []byte
err = rows.Scan(&key1, &keyid, &key2, &key3, &key4, &key5, &key6, &data)
if err != nil {
break
}
}
if context.DeadlineExceeded != ctx.Err() {
t.Fatal(ctx.Err())
}
}

View File

@ -10,5 +10,6 @@ package sqlite3
#cgo CFLAGS: -DUSE_LIBSQLITE3 #cgo CFLAGS: -DUSE_LIBSQLITE3
#cgo linux LDFLAGS: -lsqlite3 #cgo linux LDFLAGS: -lsqlite3
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3 #cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
#cgo solaris LDFLAGS: -lsqlite3
*/ */
import "C" import "C"

View File

@ -9,5 +9,6 @@ package sqlite3
/* /*
#cgo CFLAGS: -I. #cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl #cgo linux LDFLAGS: -ldl
#cgo solaris LDFLAGS: -lc
*/ */
import "C" import "C"

View File

@ -6,21 +6,22 @@
package sqlite3 package sqlite3
import ( import (
"bytes"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand"
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"regexp" "regexp"
"strconv"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/mattn/go-sqlite3/sqlite3_test"
) )
func TempFilename(t *testing.T) string { func TempFilename(t *testing.T) string {
@ -870,18 +871,6 @@ func TestTimezoneConversion(t *testing.T) {
} }
} }
func TestSuite(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
t.Fatal(err)
}
defer db.Close()
sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
}
// TODO: Execer & Queryer currently disabled // TODO: Execer & Queryer currently disabled
// https://github.com/mattn/go-sqlite3/issues/82 // https://github.com/mattn/go-sqlite3/issues/82
func TestExecer(t *testing.T) { func TestExecer(t *testing.T) {
@ -1355,6 +1344,61 @@ func TestUpdateAndTransactionHooks(t *testing.T) {
} }
} }
func TestNilAndEmptyBytes(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
actualNil := []byte("use this to use an actual nil not a reference to nil")
emptyBytes := []byte{}
for tsti, tst := range []struct {
name string
columnType string
insertBytes []byte
expectedBytes []byte
}{
{"actual nil blob", "blob", actualNil, nil},
{"referenced nil blob", "blob", nil, nil},
{"empty blob", "blob", emptyBytes, emptyBytes},
{"actual nil text", "text", actualNil, nil},
{"referenced nil text", "text", nil, nil},
{"empty text", "text", emptyBytes, emptyBytes},
} {
if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
t.Fatal(tst.name, err)
}
if bytes.Equal(tst.insertBytes, actualNil) {
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
t.Fatal(tst.name, err)
}
} else {
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
t.Fatal(tst.name, err)
}
}
rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
if err != nil {
t.Fatal(tst.name, err)
}
if !rows.Next() {
t.Fatal(tst.name, "no rows")
}
var scanBytes []byte
if err = rows.Scan(&scanBytes); err != nil {
t.Fatal(tst.name, err)
}
if err = rows.Err(); err != nil {
t.Fatal(tst.name, err)
}
if tst.expectedBytes == nil && scanBytes != nil {
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
} else if !bytes.Equal(scanBytes, tst.expectedBytes) {
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
}
}
}
var customFunctionOnce sync.Once var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) { func BenchmarkCustomFunctions(b *testing.B) {
@ -1389,3 +1433,422 @@ func BenchmarkCustomFunctions(b *testing.B) {
} }
} }
} }
func TestSuite(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
t.Fatal(err)
}
defer d.Close()
db = &TestDB{t, d, SQLITE, sync.Once{}}
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
if !testing.Short() {
for _, b := range benchmarks {
fmt.Printf("%-20s", b.Name)
r := testing.Benchmark(b.F)
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
}
}
db.tearDown()
}
// Dialect is a type of dialect of databases.
type Dialect int
// Dialects for databases.
const (
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
MYSQL // MYSQL mean MySQL dialect
)
// DB provide context for the tests
type TestDB struct {
*testing.T
*sql.DB
dialect Dialect
once sync.Once
}
var db *TestDB
// the following tables will be created and dropped during the test
var testTables = []string{"foo", "bar", "t", "bench"}
var tests = []testing.InternalTest{
{Name: "TestResult", F: testResult},
{Name: "TestBlobs", F: testBlobs},
{Name: "TestManyQueryRow", F: testManyQueryRow},
{Name: "TestTxQuery", F: testTxQuery},
{Name: "TestPreparedStmt", F: testPreparedStmt},
}
var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkExec", F: benchmarkExec},
{Name: "BenchmarkQuery", F: benchmarkQuery},
{Name: "BenchmarkParams", F: benchmarkParams},
{Name: "BenchmarkStmt", F: benchmarkStmt},
{Name: "BenchmarkRows", F: benchmarkRows},
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
}
func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
res, err := db.Exec(sql, args...)
if err != nil {
db.Fatalf("Error running %q: %v", sql, err)
}
return res
}
func (db *TestDB) tearDown() {
for _, tbl := range testTables {
switch db.dialect {
case SQLITE:
db.mustExec("drop table if exists " + tbl)
case MYSQL, POSTGRESQL:
db.mustExec("drop table if exists " + tbl)
default:
db.Fatal("unknown dialect")
}
}
}
// q replaces ? parameters if needed
func (db *TestDB) q(sql string) string {
switch db.dialect {
case POSTGRESQL: // repace with $1, $2, ..
qrx := regexp.MustCompile(`\?`)
n := 0
return qrx.ReplaceAllStringFunc(sql, func(string) string {
n++
return "$" + strconv.Itoa(n)
})
}
return sql
}
func (db *TestDB) blobType(size int) string {
switch db.dialect {
case SQLITE:
return fmt.Sprintf("blob[%d]", size)
case POSTGRESQL:
return "bytea"
case MYSQL:
return fmt.Sprintf("VARBINARY(%d)", size)
}
panic("unknown dialect")
}
func (db *TestDB) serialPK() string {
switch db.dialect {
case SQLITE:
return "integer primary key autoincrement"
case POSTGRESQL:
return "serial primary key"
case MYSQL:
return "integer primary key auto_increment"
}
panic("unknown dialect")
}
func (db *TestDB) now() string {
switch db.dialect {
case SQLITE:
return "datetime('now')"
case POSTGRESQL:
return "now()"
case MYSQL:
return "now()"
}
panic("unknown dialect")
}
func makeBench() {
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
panic(err)
}
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
if err != nil {
panic(err)
}
defer st.Close()
for i := 0; i < 100; i++ {
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
panic(err)
}
}
}
// testResult is test for result
func testResult(t *testing.T) {
db.tearDown()
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
for i := 1; i < 3; i++ {
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
n, err := r.RowsAffected()
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("got %v, want %v", n, 1)
}
n, err = r.LastInsertId()
if err != nil {
t.Fatal(err)
}
if n != int64(i) {
t.Errorf("got %v, want %v", n, i)
}
}
if _, err := db.Exec("error!"); err == nil {
t.Fatalf("expected error")
}
}
// testBlobs is test for blobs
func testBlobs(t *testing.T) {
db.tearDown()
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
// testManyQueryRow is test for many query row
func testManyQueryRow(t *testing.T) {
if testing.Short() {
t.Log("skipping in short mode")
return
}
db.tearDown()
db.mustExec("create table foo (id integer primary key, name varchar(50))")
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
// testTxQuery is test for transactional query
func testTxQuery(t *testing.T) {
db.tearDown()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
if err != nil {
t.Fatal(err)
}
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
// testPreparedStmt is test for prepared statement
func testPreparedStmt(t *testing.T) {
db.tearDown()
db.mustExec("CREATE TABLE t (count INT)")
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
wg.Wait()
}
// Benchmarks need to use panic() since b.Error errors are lost when
// running via testing.Benchmark() I would like to run these via go
// test -bench but calling Benchmark() from a benchmark test
// currently hangs go.
// benchmarkExec is benchmark for exec
func benchmarkExec(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := db.Exec("select 1"); err != nil {
panic(err)
}
}
}
// benchmarkQuery is benchmark for query
func benchmarkQuery(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// benchmarkParams is benchmark for params
func benchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// benchmarkStmt is benchmark for statement
func benchmarkStmt(b *testing.B) {
st, err := db.Prepare("select ?, ?, ?, ?")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// benchmarkRows is benchmark for rows
func benchmarkRows(b *testing.B) {
db.once.Do(makeBench)
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := db.Query("select * from bench")
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}
// benchmarkStmtRows is benchmark for statement rows
func benchmarkStmtRows(b *testing.B) {
db.once.Do(makeBench)
st, err := db.Prepare("select * from bench")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := st.Query()
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}

View File

@ -1,423 +0,0 @@
package sqlite3_test
import (
"database/sql"
"fmt"
"math/rand"
"regexp"
"strconv"
"sync"
"testing"
"time"
)
// Dialect is a type of dialect of databases.
type Dialect int
// Dialects for databases.
const (
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
MYSQL // MYSQL mean MySQL dialect
)
// DB provide context for the tests
type DB struct {
*testing.T
*sql.DB
dialect Dialect
once sync.Once
}
var db *DB
// the following tables will be created and dropped during the test
var testTables = []string{"foo", "bar", "t", "bench"}
var tests = []testing.InternalTest{
{Name: "TestBlobs", F: TestBlobs},
{Name: "TestManyQueryRow", F: TestManyQueryRow},
{Name: "TestTxQuery", F: TestTxQuery},
{Name: "TestPreparedStmt", F: TestPreparedStmt},
}
var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkExec", F: BenchmarkExec},
{Name: "BenchmarkQuery", F: BenchmarkQuery},
{Name: "BenchmarkParams", F: BenchmarkParams},
{Name: "BenchmarkStmt", F: BenchmarkStmt},
{Name: "BenchmarkRows", F: BenchmarkRows},
{Name: "BenchmarkStmtRows", F: BenchmarkStmtRows},
}
// RunTests runs the SQL test suite
func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
db = &DB{t, d, dialect, sync.Once{}}
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
if !testing.Short() {
for _, b := range benchmarks {
fmt.Printf("%-20s", b.Name)
r := testing.Benchmark(b.F)
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
}
}
db.tearDown()
}
func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
res, err := db.Exec(sql, args...)
if err != nil {
db.Fatalf("Error running %q: %v", sql, err)
}
return res
}
func (db *DB) tearDown() {
for _, tbl := range testTables {
switch db.dialect {
case SQLITE:
db.mustExec("drop table if exists " + tbl)
case MYSQL, POSTGRESQL:
db.mustExec("drop table if exists " + tbl)
default:
db.Fatal("unknown dialect")
}
}
}
// q replaces ? parameters if needed
func (db *DB) q(sql string) string {
switch db.dialect {
case POSTGRESQL: // repace with $1, $2, ..
qrx := regexp.MustCompile(`\?`)
n := 0
return qrx.ReplaceAllStringFunc(sql, func(string) string {
n++
return "$" + strconv.Itoa(n)
})
}
return sql
}
func (db *DB) blobType(size int) string {
switch db.dialect {
case SQLITE:
return fmt.Sprintf("blob[%d]", size)
case POSTGRESQL:
return "bytea"
case MYSQL:
return fmt.Sprintf("VARBINARY(%d)", size)
}
panic("unknown dialect")
}
func (db *DB) serialPK() string {
switch db.dialect {
case SQLITE:
return "integer primary key autoincrement"
case POSTGRESQL:
return "serial primary key"
case MYSQL:
return "integer primary key auto_increment"
}
panic("unknown dialect")
}
func (db *DB) now() string {
switch db.dialect {
case SQLITE:
return "datetime('now')"
case POSTGRESQL:
return "now()"
case MYSQL:
return "now()"
}
panic("unknown dialect")
}
func makeBench() {
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
panic(err)
}
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
if err != nil {
panic(err)
}
defer st.Close()
for i := 0; i < 100; i++ {
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
panic(err)
}
}
}
// TestResult is test for result
func TestResult(t *testing.T) {
db.tearDown()
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
for i := 1; i < 3; i++ {
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
n, err := r.RowsAffected()
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("got %v, want %v", n, 1)
}
n, err = r.LastInsertId()
if err != nil {
t.Fatal(err)
}
if n != int64(i) {
t.Errorf("got %v, want %v", n, i)
}
}
if _, err := db.Exec("error!"); err == nil {
t.Fatalf("expected error")
}
}
// TestBlobs is test for blobs
func TestBlobs(t *testing.T) {
db.tearDown()
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
// TestManyQueryRow is test for many query row
func TestManyQueryRow(t *testing.T) {
if testing.Short() {
t.Log("skipping in short mode")
return
}
db.tearDown()
db.mustExec("create table foo (id integer primary key, name varchar(50))")
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
// TestTxQuery is test for transactional query
func TestTxQuery(t *testing.T) {
db.tearDown()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
if err != nil {
t.Fatal(err)
}
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
// TestPreparedStmt is test for prepared statement
func TestPreparedStmt(t *testing.T) {
db.tearDown()
db.mustExec("CREATE TABLE t (count INT)")
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
wg.Wait()
}
// Benchmarks need to use panic() since b.Error errors are lost when
// running via testing.Benchmark() I would like to run these via go
// test -bench but calling Benchmark() from a benchmark test
// currently hangs go.
// BenchmarkExec is benchmark for exec
func BenchmarkExec(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := db.Exec("select 1"); err != nil {
panic(err)
}
}
}
// BenchmarkQuery is benchmark for query
func BenchmarkQuery(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// BenchmarkParams is benchmark for params
func BenchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// BenchmarkStmt is benchmark for statement
func BenchmarkStmt(b *testing.B) {
st, err := db.Prepare("select ?, ?, ?, ?")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
// BenchmarkRows is benchmark for rows
func BenchmarkRows(b *testing.B) {
db.once.Do(makeBench)
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := db.Query("select * from bench")
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}
// BenchmarkStmtRows is benchmark for statement rows
func BenchmarkStmtRows(b *testing.B) {
db.once.Do(makeBench)
st, err := db.Prepare("select * from bench")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := st.Query()
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}

View File

@ -1,11 +1,13 @@
language: go language: go
go: go:
- 1.2 - 1.2.x
- 1.3 - 1.3.x
- 1.4 - 1.4.x
- 1.5 - 1.5.x
- 1.6 - 1.6.x
- 1.7.x
- 1.8.x
- tip - tip
install: install:

View File

@ -65,4 +65,6 @@ var (
ShouldHappenWithin = assertions.ShouldHappenWithin ShouldHappenWithin = assertions.ShouldHappenWithin
ShouldNotHappenWithin = assertions.ShouldNotHappenWithin ShouldNotHappenWithin = assertions.ShouldNotHappenWithin
ShouldBeChronological = assertions.ShouldBeChronological ShouldBeChronological = assertions.ShouldBeChronological
ShouldBeError = assertions.ShouldBeError
) )

View File

@ -1 +0,0 @@
package main

View File

@ -35,13 +35,13 @@ func init() {
func flags() { func flags() {
flag.IntVar(&port, "port", 8080, "The port at which to serve http.") flag.IntVar(&port, "port", 8080, "The port at which to serve http.")
flag.StringVar(&host, "host", "127.0.0.1", "The host at which to serve http.") flag.StringVar(&host, "host", "127.0.0.1", "The host at which to serve http.")
flag.DurationVar(&nap, "poll", quarterSecond, "The interval to wait between polling the file system for changes (default: 250ms).") flag.DurationVar(&nap, "poll", quarterSecond, "The interval to wait between polling the file system for changes.")
flag.IntVar(&packages, "packages", 10, "The number of packages to test in parallel. Higher == faster but more costly in terms of computing. (default: 10)") flag.IntVar(&packages, "packages", 10, "The number of packages to test in parallel. Higher == faster but more costly in terms of computing.")
flag.StringVar(&gobin, "gobin", "go", "The path to the 'go' binary (default: search on the PATH).") flag.StringVar(&gobin, "gobin", "go", "The path to the 'go' binary (default: search on the PATH).")
flag.BoolVar(&cover, "cover", true, "Enable package-level coverage statistics. Requires Go 1.2+ and the go cover tool. (default: true)") flag.BoolVar(&cover, "cover", true, "Enable package-level coverage statistics. Requires Go 1.2+ and the go cover tool.")
flag.IntVar(&depth, "depth", -1, "The directory scanning depth. If -1, scan infinitely deep directory structures. 0: scan working directory. 1+: Scan into nested directories, limited to value. (default: -1)") flag.IntVar(&depth, "depth", -1, "The directory scanning depth. If -1, scan infinitely deep directory structures. 0: scan working directory. 1+: Scan into nested directories, limited to value.")
flag.StringVar(&timeout, "timeout", "0", "The test execution timeout if none is specified in the *.goconvey file (default is '0', which is the same as not providing this option).") flag.StringVar(&timeout, "timeout", "0", "The test execution timeout if none is specified in the *.goconvey file (default is '0', which is the same as not providing this option).")
flag.StringVar(&watchedSuffixes, "watchedSuffixes", ".go", "A comma separated list of file suffixes to watch for modifications (default: .go).") flag.StringVar(&watchedSuffixes, "watchedSuffixes", ".go", "A comma separated list of file suffixes to watch for modifications.")
flag.StringVar(&excludedDirs, "excludedDirs", "vendor,node_modules", "A comma separated list of directories that will be excluded from being watched") flag.StringVar(&excludedDirs, "excludedDirs", "vendor,node_modules", "A comma separated list of directories that will be excluded from being watched")
flag.StringVar(&workDir, "workDir", "", "set goconvey working directory (default current directory)") flag.StringVar(&workDir, "workDir", "", "set goconvey working directory (default current directory)")
flag.BoolVar(&autoLaunchBrowser, "launchBrowser", true, "toggle auto launching of browser (default: true)") flag.BoolVar(&autoLaunchBrowser, "launchBrowser", true, "toggle auto launching of browser (default: true)")

Some files were not shown because too many files have changed in this diff Show More