DB Migrations #67

Merged
konrad merged 44 commits from feature/migrations into master 2019-03-29 17:54:36 +00:00
53 changed files with 2715 additions and 1541 deletions
Showing only changes of commit 3e453d90ad - Show all commits

View File

@ -26,4 +26,63 @@ func Migrate() {
// Because init() does not guarantee the order in which these are added to the slice, // Because init() does not guarantee the order in which these are added to the slice,
// we need to sort them to ensure that they are in order // we need to sort them to ensure that they are in order
/*
XORM logs from try.vikunja.io:
Table users Column email db default is 'NULL', struct default is
Table users Column is_active db default is NULL, struct default is
Table users Column password_reset_token db default is 'NULL', struct default is
Table users Column email_confirm_token db default is 'NULL', struct default is
Table users Column created db default is NULL, struct default is
Table users Column updated db default is NULL, struct default is
Table list Column title db default is 'NULL', struct default is
Table list Column description db default is 'NULL', struct default is
Table list Column owner_id db default is NULL, struct default is
Table list Column namespace_id db default is NULL, struct default is
Table list Column created db default is NULL, struct default is
Table list Column updated db default is NULL, struct default is
Table tasks Column text db default is 'NULL', struct default is
Table tasks Column description db default is 'NULL', struct default is
Table tasks Column done db default is NULL, struct default is
Table tasks Column due_date_unix db default is NULL, struct default is
Table tasks Column reminders_unix db default is 'NULL', struct default is
Table tasks Column created_by_id db default is NULL, struct default is
Table tasks Column list_id db default is NULL, struct default is
Table tasks Column repeat_after db default is NULL, struct default is
Table tasks Column parent_task_id db default is NULL, struct default is
Table tasks Column priority db default is NULL, struct default is
Table tasks Column start_date_unix db default is NULL, struct default is
Table tasks Column end_date_unix db default is NULL, struct default is
Table tasks Column created db default is NULL, struct default is
Table tasks Column updated db default is NULL, struct default is
Table teams Column description db default is 'NULL', struct default is
Table teams Column created db default is NULL, struct default is
Table teams Column updated db default is NULL, struct default is
Table team_members Column admin db default is NULL, struct default is
Table team_members Column created db default is NULL, struct default is
Table team_list Column right db default is NULL, struct default is
Table team_list Column created db default is NULL, struct default is
Table team_list Column updated db default is NULL, struct default is
Table team_namespaces Column right db default is NULL, struct default is
Table team_namespaces Column created db default is NULL, struct default is
Table team_namespaces Column updated db default is NULL, struct default is
Table namespaces Column name db default is 'NULL', struct default is
Table namespaces Column description db default is 'NULL', struct default is
Table namespaces Column created db default is NULL, struct default is
Table namespaces Column updated db default is NULL, struct default is
Table users_list Column right db default is NULL, struct default is
Table users_list Column created db default is NULL, struct default is
Table users_list Column updated db default is NULL, struct default is
Table users_namespace Column right db default is NULL, struct default is
Table users_namespace Column created db default is NULL, struct default is
Table users_namespace Column updated db default is NULL, struct default is
Table task_assignees Column created db default is NULL, struct default is
Table labels Column description db default is 'NULL', struct default is
Table labels Column hex_color db default is 'NULL', struct default is
Table labels Column created db default is NULL, struct default is
Table labels Column updated db default is NULL, struct default is
Table label_task Column created db default is NULL, struct default is
Table tasks has column reminder_unix but struct has not related field
Table team_members has column updated but struct has not related field
*/
} }

View File

@ -1,26 +1,47 @@
# SQL builder # SQL builder
[![CircleCI](https://circleci.com/gh/go-xorm/builder/tree/master.svg?style=svg)](https://circleci.com/gh/go-xorm/builder/tree/master) [![GitCI.cn](https://gitci.cn/api/badges/go-xorm/builder/status.svg)](https://gitci.cn/go-xorm/builder) [![codecov](https://codecov.io/gh/go-xorm/builder/branch/master/graph/badge.svg)](https://codecov.io/gh/go-xorm/builder)
[![](https://goreportcard.com/badge/github.com/go-xorm/builder)](https://goreportcard.com/report/github.com/go-xorm/builder)
Package builder is a lightweight and fast SQL builder for Go and XORM. Package builder is a lightweight and fast SQL builder for Go and XORM.
Make sure you have installed Go 1.1+ and then: Make sure you have installed Go 1.8+ and then:
go get github.com/go-xorm/builder go get github.com/go-xorm/builder
# Insert # Insert
```Go ```Go
sql, args, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL() sql, args, err := builder.Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
// INSERT INTO table1 SELECT * FROM table2
sql, err := builder.Insert().Into("table1").Select().From("table2").ToBoundSQL()
// INSERT INTO table1 (a, b) SELECT b, c FROM table2
sql, err = builder.Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
``` ```
# Select # Select
```Go ```Go
// Simple Query
sql, args, err := Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL() sql, args, err := Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
// With join
sql, args, err = Select("c, d").From("table1").LeftJoin("table2", Eq{"table1.id": 1}.And(Lt{"table2.id": 3})). sql, args, err = Select("c, d").From("table1").LeftJoin("table2", Eq{"table1.id": 1}.And(Lt{"table2.id": 3})).
RightJoin("table3", "table2.id = table3.tid").Where(Eq{"a": 1}).ToSQL() RightJoin("table3", "table2.id = table3.tid").Where(Eq{"a": 1}).ToSQL()
// From sub query
sql, args, err := Select("sub.id").From(Select("c").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
// From union query
sql, args, err = Select("sub.id").From(
Select("id").From("table1").Where(Eq{"a": 1}).Union("all", Select("id").From("table1").Where(Eq{"a": 2})),"sub").
Where(Eq{"b": 1}).ToSQL()
// With order by
sql, args, err = Select("a", "b", "c").From("table1").Where(Eq{"f1": "v1", "f2": "v2"}).
OrderBy("a ASC").ToSQL()
// With limit.
// Be careful! You should set up specific dialect for builder before performing a query with LIMIT
sql, args, err = Dialect(MYSQL).Select("a", "b", "c").From("table1").OrderBy("a ASC").
Limit(5, 10).ToSQL()
``` ```
# Update # Update
@ -35,6 +56,16 @@ sql, args, err := Update(Eq{"a": 2}).From("table1").Where(Eq{"a": 1}).ToSQL()
sql, args, err := Delete(Eq{"a": 1}).From("table1").ToSQL() sql, args, err := Delete(Eq{"a": 1}).From("table1").ToSQL()
``` ```
# Union
```Go
sql, args, err := Select("*").From("a").Where(Eq{"status": "1"}).
Union("all", Select("*").From("a").Where(Eq{"status": "2"})).
Union("distinct", Select("*").From("a").Where(Eq{"status": "3"})).
Union("", Select("*").From("a").Where(Eq{"status": "4"})).
ToSQL()
```
# Conditions # Conditions
* `Eq` is a redefine of a map, you can give one or more conditions to `Eq` * `Eq` is a redefine of a map, you can give one or more conditions to `Eq`

View File

@ -4,6 +4,12 @@
package builder package builder
import (
sql2 "database/sql"
"fmt"
"sort"
)
type optype byte type optype byte
const ( const (
@ -12,6 +18,15 @@ const (
insertType // insert insertType // insert
updateType // update updateType // update
deleteType // delete deleteType // delete
unionType // union
)
const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
MSSQL = "mssql"
ORACLE = "oracle"
) )
type join struct { type join struct {
@ -20,60 +35,115 @@ type join struct {
joinCond Cond joinCond Cond
} }
type union struct {
unionType string
builder *Builder
}
type limit struct {
limitN int
offset int
}
// Builder describes a SQL statement // Builder describes a SQL statement
type Builder struct { type Builder struct {
optype optype
tableName string dialect string
cond Cond isNested bool
selects []string into string
joins []join from string
inserts Eq subQuery *Builder
updates []Eq cond Cond
selects []string
joins []join
unions []union
limitation *limit
insertCols []string
insertVals []interface{}
updates []Eq
orderBy string
groupBy string
having string
} }
// Select creates a select Builder // Dialect sets the db dialect of Builder.
func Select(cols ...string) *Builder { func Dialect(dialect string) *Builder {
builder := &Builder{cond: NewCond()} builder := &Builder{cond: NewCond(), dialect: dialect}
return builder.Select(cols...) return builder
} }
// Insert creates an insert Builder // MySQL is shortcut of Dialect(MySQL)
func Insert(eq Eq) *Builder { func MySQL() *Builder {
builder := &Builder{cond: NewCond()} return Dialect(MYSQL)
return builder.Insert(eq)
} }
// Update creates an update Builder // MsSQL is shortcut of Dialect(MsSQL)
func Update(updates ...Eq) *Builder { func MsSQL() *Builder {
builder := &Builder{cond: NewCond()} return Dialect(MSSQL)
return builder.Update(updates...)
} }
// Delete creates a delete Builder // Oracle is shortcut of Dialect(Oracle)
func Delete(conds ...Cond) *Builder { func Oracle() *Builder {
builder := &Builder{cond: NewCond()} return Dialect(ORACLE)
return builder.Delete(conds...) }
// Postgres is shortcut of Dialect(Postgres)
func Postgres() *Builder {
return Dialect(POSTGRES)
}
// SQLite is shortcut of Dialect(SQLITE)
func SQLite() *Builder {
return Dialect(SQLITE)
} }
// Where sets where SQL // Where sets where SQL
func (b *Builder) Where(cond Cond) *Builder { func (b *Builder) Where(cond Cond) *Builder {
b.cond = b.cond.And(cond) if b.cond.IsValid() {
b.cond = b.cond.And(cond)
} else {
b.cond = cond
}
return b return b
} }
// From sets the table name // From sets from subject(can be a table name in string or a builder pointer) and its alias
func (b *Builder) From(tableName string) *Builder { func (b *Builder) From(subject interface{}, alias ...string) *Builder {
b.tableName = tableName switch subject.(type) {
case *Builder:
b.subQuery = subject.(*Builder)
if len(alias) > 0 {
b.from = alias[0]
} else {
b.isNested = true
}
case string:
b.from = subject.(string)
if len(alias) > 0 {
b.from = b.from + " " + alias[0]
}
}
return b return b
} }
// TableName returns the table name
func (b *Builder) TableName() string {
if b.optype == insertType {
return b.into
}
return b.from
}
// Into sets insert table name // Into sets insert table name
func (b *Builder) Into(tableName string) *Builder { func (b *Builder) Into(tableName string) *Builder {
b.tableName = tableName b.into = tableName
return b return b
} }
// Join sets join table and contions // Join sets join table and conditions
func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builder { func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builder {
switch joinCond.(type) { switch joinCond.(type) {
case Cond: case Cond:
@ -85,6 +155,50 @@ func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builde
return b return b
} }
// Union sets union conditions
func (b *Builder) Union(unionTp string, unionCond *Builder) *Builder {
var builder *Builder
if b.optype != unionType {
builder = &Builder{cond: NewCond()}
builder.optype = unionType
builder.dialect = b.dialect
builder.selects = b.selects
currentUnions := b.unions
// erase sub unions (actually append to new Builder.unions)
b.unions = nil
for e := range currentUnions {
currentUnions[e].builder.dialect = b.dialect
}
builder.unions = append(append(builder.unions, union{"", b}), currentUnions...)
} else {
builder = b
}
if unionCond != nil {
if unionCond.dialect == "" && builder.dialect != "" {
unionCond.dialect = builder.dialect
}
builder.unions = append(builder.unions, union{unionTp, unionCond})
}
return builder
}
// Limit sets limitN condition
func (b *Builder) Limit(limitN int, offset ...int) *Builder {
b.limitation = &limit{limitN: limitN}
if len(offset) > 0 {
b.limitation.offset = offset[0]
}
return b
}
// InnerJoin sets inner join // InnerJoin sets inner join
func (b *Builder) InnerJoin(joinTable string, joinCond interface{}) *Builder { func (b *Builder) InnerJoin(joinTable string, joinCond interface{}) *Builder {
return b.Join("INNER", joinTable, joinCond) return b.Join("INNER", joinTable, joinCond)
@ -113,7 +227,9 @@ func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
// Select sets select SQL // Select sets select SQL
func (b *Builder) Select(cols ...string) *Builder { func (b *Builder) Select(cols ...string) *Builder {
b.selects = cols b.selects = cols
b.optype = selectType if b.optype == condType {
b.optype = selectType
}
return b return b
} }
@ -130,15 +246,52 @@ func (b *Builder) Or(cond Cond) *Builder {
} }
// Insert sets insert SQL // Insert sets insert SQL
func (b *Builder) Insert(eq Eq) *Builder { func (b *Builder) Insert(eq ...interface{}) *Builder {
b.inserts = eq if len(eq) > 0 {
var paramType = -1
for _, e := range eq {
switch t := e.(type) {
case Eq:
if paramType == -1 {
paramType = 0
}
if paramType != 0 {
break
}
for k, v := range t {
b.insertCols = append(b.insertCols, k)
b.insertVals = append(b.insertVals, v)
}
case string:
if paramType == -1 {
paramType = 1
}
if paramType != 1 {
break
}
b.insertCols = append(b.insertCols, t)
}
}
}
if len(b.insertCols) == len(b.insertVals) {
sort.Slice(b.insertVals, func(i, j int) bool {
return b.insertCols[i] < b.insertCols[j]
})
sort.Strings(b.insertCols)
}
b.optype = insertType b.optype = insertType
return b return b
} }
// Update sets update SQL // Update sets update SQL
func (b *Builder) Update(updates ...Eq) *Builder { func (b *Builder) Update(updates ...Eq) *Builder {
b.updates = updates b.updates = make([]Eq, 0, len(updates))
for _, update := range updates {
if update.IsValid() {
b.updates = append(b.updates, update)
}
}
b.optype = updateType b.optype = updateType
return b return b
} }
@ -153,8 +306,8 @@ func (b *Builder) Delete(conds ...Cond) *Builder {
// WriteTo implements Writer interface // WriteTo implements Writer interface
func (b *Builder) WriteTo(w Writer) error { func (b *Builder) WriteTo(w Writer) error {
switch b.optype { switch b.optype {
case condType: /*case condType:
return b.cond.WriteTo(w) return b.cond.WriteTo(w)*/
case selectType: case selectType:
return b.selectWriteTo(w) return b.selectWriteTo(w)
case insertType: case insertType:
@ -163,6 +316,8 @@ func (b *Builder) WriteTo(w Writer) error {
return b.updateWriteTo(w) return b.updateWriteTo(w)
case deleteType: case deleteType:
return b.deleteWriteTo(w) return b.deleteWriteTo(w)
case unionType:
return b.unionWriteTo(w)
} }
return ErrNotSupportType return ErrNotSupportType
@ -175,16 +330,48 @@ func (b *Builder) ToSQL() (string, []interface{}, error) {
return "", nil, err return "", nil, err
} }
return w.writer.String(), w.args, nil // in case of sql.NamedArg in args
for e := range w.args {
if namedArg, ok := w.args[e].(sql2.NamedArg); ok {
w.args[e] = namedArg.Value
}
}
var sql = w.writer.String()
var err error
switch b.dialect {
case ORACLE, MSSQL:
// This is for compatibility with different sql drivers
for e := range w.args {
w.args[e] = sql2.Named(fmt.Sprintf("p%d", e+1), w.args[e])
}
var prefix string
if b.dialect == ORACLE {
prefix = ":p"
} else {
prefix = "@p"
}
if sql, err = ConvertPlaceholder(sql, prefix); err != nil {
return "", nil, err
}
case POSTGRES:
if sql, err = ConvertPlaceholder(sql, "$"); err != nil {
return "", nil, err
}
}
return sql, w.args, nil
} }
// ToSQL convert a builder or condtions to SQL and args // ToBoundSQL
func ToSQL(cond interface{}) (string, []interface{}, error) { func (b *Builder) ToBoundSQL() (string, error) {
switch cond.(type) { w := NewWriter()
case Cond: if err := b.WriteTo(w); err != nil {
return condToSQL(cond.(Cond)) return "", err
case *Builder:
return cond.(*Builder).ToSQL()
} }
return "", nil, ErrNotSupportType
return ConvertToBoundSQL(w.writer.String(), w.args)
} }

View File

@ -5,16 +5,21 @@
package builder package builder
import ( import (
"errors"
"fmt" "fmt"
) )
// Delete creates a delete Builder
func Delete(conds ...Cond) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Delete(conds...)
}
func (b *Builder) deleteWriteTo(w Writer) error { func (b *Builder) deleteWriteTo(w Writer) error {
if len(b.tableName) <= 0 { if len(b.from) <= 0 {
return errors.New("no table indicated") return ErrNoTableName
} }
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil { if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil {
return err return err
} }

View File

@ -6,37 +6,63 @@ package builder
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
) )
func (b *Builder) insertWriteTo(w Writer) error { // Insert creates an insert Builder
if len(b.tableName) <= 0 { func Insert(eq ...interface{}) *Builder {
return errors.New("no table indicated") builder := &Builder{cond: NewCond()}
} return builder.Insert(eq...)
if len(b.inserts) <= 0 { }
return errors.New("no column to be update")
func (b *Builder) insertSelectWriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil {
return err
} }
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil { if len(b.insertCols) > 0 {
fmt.Fprintf(w, "(")
for _, col := range b.insertCols {
fmt.Fprintf(w, col)
}
fmt.Fprintf(w, ") ")
}
return b.selectWriteTo(w)
}
func (b *Builder) insertWriteTo(w Writer) error {
if len(b.into) <= 0 {
return ErrNoTableName
}
if len(b.insertCols) <= 0 && b.from == "" {
return ErrNoColumnToInsert
}
if b.into != "" && b.from != "" {
return b.insertSelectWriteTo(w)
}
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil {
return err return err
} }
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
var bs []byte var bs []byte
var valBuffer = bytes.NewBuffer(bs) var valBuffer = bytes.NewBuffer(bs)
var i = 0
for col, value := range b.inserts { for i, col := range b.insertCols {
value := b.insertVals[i]
fmt.Fprint(w, col) fmt.Fprint(w, col)
if e, ok := value.(expr); ok { if e, ok := value.(expr); ok {
fmt.Fprint(valBuffer, e.sql) fmt.Fprintf(valBuffer, "(%s)", e.sql)
args = append(args, e.args...) args = append(args, e.args...)
} else { } else {
fmt.Fprint(valBuffer, "?") fmt.Fprint(valBuffer, "?")
args = append(args, value) args = append(args, value)
} }
if i != len(b.inserts)-1 { if i != len(b.insertCols)-1 {
if _, err := fmt.Fprint(w, ","); err != nil { if _, err := fmt.Fprint(w, ","); err != nil {
return err return err
} }
@ -44,7 +70,6 @@ func (b *Builder) insertWriteTo(w Writer) error {
return err return err
} }
} }
i = i + 1
} }
if _, err := fmt.Fprint(w, ") Values ("); err != nil { if _, err := fmt.Fprint(w, ") Values ("); err != nil {

View File

@ -5,13 +5,24 @@
package builder package builder
import ( import (
"errors"
"fmt" "fmt"
) )
// Select creates a select Builder
func Select(cols ...string) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Select(cols...)
}
func (b *Builder) selectWriteTo(w Writer) error { func (b *Builder) selectWriteTo(w Writer) error {
if len(b.tableName) <= 0 { if len(b.from) <= 0 && !b.isNested {
return errors.New("no table indicated") return ErrNoTableName
}
// perform limit before writing to writer when b.dialect between ORACLE and MSSQL
// this avoid a duplicate writing problem in simple limit query
if b.limitation != nil && (b.dialect == ORACLE || b.dialect == MSSQL) {
return b.limitWriteTo(w)
} }
if _, err := fmt.Fprint(w, "SELECT "); err != nil { if _, err := fmt.Fprint(w, "SELECT "); err != nil {
@ -34,20 +45,101 @@ func (b *Builder) selectWriteTo(w Writer) error {
} }
} }
if _, err := fmt.Fprintf(w, " FROM %s", b.tableName); err != nil { if b.subQuery == nil {
return err if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil {
return err
}
} else {
if b.cond.IsValid() && len(b.from) <= 0 {
return ErrUnnamedDerivedTable
}
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect {
return ErrInconsistentDialect
}
// dialect of sub-query will inherit from the main one (if not set up)
if b.dialect != "" && b.subQuery.dialect == "" {
b.subQuery.dialect = b.dialect
}
switch b.subQuery.optype {
case selectType, unionType:
fmt.Fprint(w, " FROM (")
if err := b.subQuery.WriteTo(w); err != nil {
return err
}
if len(b.from) == 0 {
fmt.Fprintf(w, ")")
} else {
fmt.Fprintf(w, ") %v", b.from)
}
default:
return ErrUnexpectedSubQuery
}
} }
for _, v := range b.joins { for _, v := range b.joins {
fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable) if _, err := fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable); err != nil {
return err
}
if err := v.joinCond.WriteTo(w); err != nil { if err := v.joinCond.WriteTo(w); err != nil {
return err return err
} }
} }
if _, err := fmt.Fprint(w, " WHERE "); err != nil { if b.cond.IsValid() {
return err if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := b.cond.WriteTo(w); err != nil {
return err
}
} }
return b.cond.WriteTo(w) if len(b.groupBy) > 0 {
if _, err := fmt.Fprint(w, " GROUP BY ", b.groupBy); err != nil {
return err
}
}
if len(b.having) > 0 {
if _, err := fmt.Fprint(w, " HAVING ", b.having); err != nil {
return err
}
}
if len(b.orderBy) > 0 {
if _, err := fmt.Fprint(w, " ORDER BY ", b.orderBy); err != nil {
return err
}
}
if b.limitation != nil {
if err := b.limitWriteTo(w); err != nil {
return err
}
}
return nil
}
// OrderBy orderBy SQL
func (b *Builder) OrderBy(orderBy string) *Builder {
b.orderBy = orderBy
return b
}
// GroupBy groupby SQL
func (b *Builder) GroupBy(groupby string) *Builder {
b.groupBy = groupby
return b
}
// Having having SQL
func (b *Builder) Having(having string) *Builder {
b.having = having
return b
} }

View File

@ -5,19 +5,24 @@
package builder package builder
import ( import (
"errors"
"fmt" "fmt"
) )
// Update creates an update Builder
func Update(updates ...Eq) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Update(updates...)
}
func (b *Builder) updateWriteTo(w Writer) error { func (b *Builder) updateWriteTo(w Writer) error {
if len(b.tableName) <= 0 { if len(b.from) <= 0 {
return errors.New("no table indicated") return ErrNoTableName
} }
if len(b.updates) <= 0 { if len(b.updates) <= 0 {
return errors.New("no column to be update") return ErrNoColumnToUpdate
} }
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.tableName); err != nil { if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.from); err != nil {
return err return err
} }

View File

@ -1,12 +0,0 @@
dependencies:
override:
# './...' is a relative pattern which means all subdirectories
- go get -t -d -v ./...
- go build -v
- go get -u github.com/golang/lint/golint
test:
override:
# './...' is a relative pattern which means all subdirectories
- golint ./...
- go test -v -race

View File

@ -5,7 +5,6 @@
package builder package builder
import ( import (
"bytes"
"io" "io"
) )
@ -19,15 +18,15 @@ var _ Writer = NewWriter()
// BytesWriter implments Writer and save SQL in bytes.Buffer // BytesWriter implments Writer and save SQL in bytes.Buffer
type BytesWriter struct { type BytesWriter struct {
writer *bytes.Buffer writer *StringBuilder
buffer []byte
args []interface{} args []interface{}
} }
// NewWriter creates a new string writer // NewWriter creates a new string writer
func NewWriter() *BytesWriter { func NewWriter() *BytesWriter {
w := &BytesWriter{} w := &BytesWriter{
w.writer = bytes.NewBuffer(w.buffer) writer: &StringBuilder{},
}
return w return w
} }
@ -73,15 +72,3 @@ func (condEmpty) Or(conds ...Cond) Cond {
func (condEmpty) IsValid() bool { func (condEmpty) IsValid() bool {
return false return false
} }
func condToSQL(cond Cond) (string, []interface{}, error) {
if cond == nil || !cond.IsValid() {
return "", nil, nil
}
w := NewWriter()
if err := cond.WriteTo(w); err != nil {
return "", nil, err
}
return w.writer.String(), w.args, nil
}

View File

@ -25,7 +25,9 @@ func And(conds ...Cond) Cond {
func (and condAnd) WriteTo(w Writer) error { func (and condAnd) WriteTo(w Writer) error {
for i, cond := range and { for i, cond := range and {
_, isOr := cond.(condOr) _, isOr := cond.(condOr)
if isOr { _, isExpr := cond.(expr)
wrap := isOr || isExpr
if wrap {
fmt.Fprint(w, "(") fmt.Fprint(w, "(")
} }
@ -34,7 +36,7 @@ func (and condAnd) WriteTo(w Writer) error {
return err return err
} }
if isOr { if wrap {
fmt.Fprint(w, ")") fmt.Fprint(w, ")")
} }

View File

@ -17,10 +17,35 @@ var _ Cond = Between{}
// WriteTo write data to Writer // WriteTo write data to Writer
func (between Between) WriteTo(w Writer) error { func (between Between) WriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "%s BETWEEN ? AND ?", between.Col); err != nil { if _, err := fmt.Fprintf(w, "%s BETWEEN ", between.Col); err != nil {
return err return err
} }
w.Append(between.LessVal, between.MoreVal) if lv, ok := between.LessVal.(expr); ok {
if err := lv.WriteTo(w); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, "?"); err != nil {
return err
}
w.Append(between.LessVal)
}
if _, err := fmt.Fprint(w, " AND "); err != nil {
return err
}
if mv, ok := between.MoreVal.(expr); ok {
if err := mv.WriteTo(w); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, "?"); err != nil {
return err
}
w.Append(between.MoreVal)
}
return nil return nil
} }

View File

@ -10,7 +10,13 @@ import "fmt"
func WriteMap(w Writer, data map[string]interface{}, op string) error { func WriteMap(w Writer, data map[string]interface{}, op string) error {
var args = make([]interface{}, 0, len(data)) var args = make([]interface{}, 0, len(data))
var i = 0 var i = 0
for k, v := range data { keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
for _, k := range keys {
v := data[k]
switch v.(type) { switch v.(type) {
case expr: case expr:
if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil { if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil {

View File

@ -4,7 +4,10 @@
package builder package builder
import "fmt" import (
"fmt"
"sort"
)
// Incr implements a type used by Eq // Incr implements a type used by Eq
type Incr int type Incr int
@ -19,7 +22,8 @@ var _ Cond = Eq{}
func (eq Eq) opWriteTo(op string, w Writer) error { func (eq Eq) opWriteTo(op string, w Writer) error {
var i = 0 var i = 0
for k, v := range eq { for _, k := range eq.sortedKeys() {
v := eq[k]
switch v.(type) { switch v.(type) {
case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}: case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}:
if err := In(k, v).WriteTo(w); err != nil { if err := In(k, v).WriteTo(w); err != nil {
@ -94,3 +98,15 @@ func (eq Eq) Or(conds ...Cond) Cond {
func (eq Eq) IsValid() bool { func (eq Eq) IsValid() bool {
return len(eq) > 0 return len(eq) > 0
} }
// sortedKeys returns all keys of this Eq sorted with sort.Strings.
// It is used internally for consistent ordering when generating
// SQL, see https://github.com/go-xorm/builder/issues/10
func (eq Eq) sortedKeys() []string {
keys := make([]string, 0, len(eq))
for key := range eq {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}

View File

@ -16,7 +16,7 @@ func (like Like) WriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil { if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil {
return err return err
} }
// FIXME: if use other regular express, this will be failed. but for compitable, keep this // FIXME: if use other regular express, this will be failed. but for compatible, keep this
if like[1][0] == '%' || like[1][len(like[1])-1] == '%' { if like[1][0] == '%' || like[1][len(like[1])-1] == '%' {
w.Append(like[1]) w.Append(like[1])
} else { } else {

View File

@ -4,7 +4,10 @@
package builder package builder
import "fmt" import (
"fmt"
"sort"
)
// Neq defines not equal conditions // Neq defines not equal conditions
type Neq map[string]interface{} type Neq map[string]interface{}
@ -15,7 +18,8 @@ var _ Cond = Neq{}
func (neq Neq) WriteTo(w Writer) error { func (neq Neq) WriteTo(w Writer) error {
var args = make([]interface{}, 0, len(neq)) var args = make([]interface{}, 0, len(neq))
var i = 0 var i = 0
for k, v := range neq { for _, k := range neq.sortedKeys() {
v := neq[k]
switch v.(type) { switch v.(type) {
case []int, []int64, []string, []int32, []int16, []int8: case []int, []int64, []string, []int32, []int16, []int8:
if err := NotIn(k, v).WriteTo(w); err != nil { if err := NotIn(k, v).WriteTo(w); err != nil {
@ -76,3 +80,15 @@ func (neq Neq) Or(conds ...Cond) Cond {
func (neq Neq) IsValid() bool { func (neq Neq) IsValid() bool {
return len(neq) > 0 return len(neq) > 0
} }
// sortedKeys returns all keys of this Neq sorted with sort.Strings.
// It is used internally for consistent ordering when generating
// SQL, see https://github.com/go-xorm/builder/issues/10
func (neq Neq) sortedKeys() []string {
keys := make([]string, 0, len(neq))
for key := range neq {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}

View File

@ -21,6 +21,18 @@ func (not Not) WriteTo(w Writer) error {
if _, err := fmt.Fprint(w, "("); err != nil { if _, err := fmt.Fprint(w, "("); err != nil {
return err return err
} }
case Eq:
if len(not[0].(Eq)) > 1 {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
}
case Neq:
if len(not[0].(Neq)) > 1 {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
}
} }
if err := not[0].WriteTo(w); err != nil { if err := not[0].WriteTo(w); err != nil {
@ -32,6 +44,18 @@ func (not Not) WriteTo(w Writer) error {
if _, err := fmt.Fprint(w, ")"); err != nil { if _, err := fmt.Fprint(w, ")"); err != nil {
return err return err
} }
case Eq:
if len(not[0].(Eq)) > 1 {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
}
case Neq:
if len(not[0].(Neq)) > 1 {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
}
} }
return nil return nil

View File

@ -27,10 +27,12 @@ func (o condOr) WriteTo(w Writer) error {
for i, cond := range o { for i, cond := range o {
var needQuote bool var needQuote bool
switch cond.(type) { switch cond.(type) {
case condAnd: case condAnd, expr:
needQuote = true needQuote = true
case Eq: case Eq:
needQuote = (len(cond.(Eq)) > 1) needQuote = (len(cond.(Eq)) > 1)
case Neq:
needQuote = (len(cond.(Neq)) > 1)
} }
if needQuote { if needQuote {

View File

@ -8,9 +8,33 @@ import "errors"
var ( var (
// ErrNotSupportType not supported SQL type error // ErrNotSupportType not supported SQL type error
ErrNotSupportType = errors.New("not supported SQL type") ErrNotSupportType = errors.New("Not supported SQL type")
// ErrNoNotInConditions no NOT IN params error // ErrNoNotInConditions no NOT IN params error
ErrNoNotInConditions = errors.New("No NOT IN conditions") ErrNoNotInConditions = errors.New("No NOT IN conditions")
// ErrNoInConditions no IN params error // ErrNoInConditions no IN params error
ErrNoInConditions = errors.New("No IN conditions") ErrNoInConditions = errors.New("No IN conditions")
// ErrNeedMoreArguments need more arguments
ErrNeedMoreArguments = errors.New("Need more sql arguments")
// ErrNoTableName no table name
ErrNoTableName = errors.New("No table indicated")
// ErrNoColumnToInsert no column to update
ErrNoColumnToUpdate = errors.New("No column(s) to update")
// ErrNoColumnToInsert no column to update
ErrNoColumnToInsert = errors.New("No column(s) to insert")
// ErrNotSupportDialectType not supported dialect type error
ErrNotSupportDialectType = errors.New("Not supported dialect type")
// ErrNotUnexpectedUnionConditions using union in a wrong way
ErrNotUnexpectedUnionConditions = errors.New("Unexpected conditional fields in UNION query")
// ErrUnsupportedUnionMembers unexpected members in UNION query
ErrUnsupportedUnionMembers = errors.New("Unexpected members in UNION query")
// ErrUnexpectedSubQuery Unexpected sub-query in SELECT query
ErrUnexpectedSubQuery = errors.New("Unexpected sub-query in SELECT query")
// ErrDialectNotSetUp dialect is not setup yet
ErrDialectNotSetUp = errors.New("Dialect is not setup yet, try to use `Dialect(dbType)` at first")
// ErrInvalidLimitation offset or limit is not correct
ErrInvalidLimitation = errors.New("Offset or limit is not correct")
// ErrUnnamedDerivedTable Every derived table must have its own alias
ErrUnnamedDerivedTable = errors.New("Every derived table must have its own alias")
// ErrInconsistentDialect Inconsistent dialect in same builder
ErrInconsistentDialect = errors.New("Inconsistent dialect in same builder")
) )

View File

@ -147,12 +147,12 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
} }
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
} else { } else {
return nil, fmt.Errorf("field %v is not valid", col.FieldName) return nil, fmt.Errorf("field %v is not valid", col.FieldName)
} }
} }
if !fieldValue.IsValid() { if !fieldValue.IsValid() {
return nil, fmt.Errorf("field %v is not valid", col.FieldName) return nil, fmt.Errorf("field %v is not valid", col.FieldName)
} }
return &fieldValue, nil return &fieldValue, nil

View File

@ -49,7 +49,6 @@ func NewTable(name string, t reflect.Type) *Table {
} }
func (table *Table) columnsByName(name string) []*Column { func (table *Table) columnsByName(name string) []*Column {
n := len(name) n := len(name)
for k := range table.columnsMap { for k := range table.columnsMap {
@ -75,7 +74,6 @@ func (table *Table) GetColumn(name string) *Column {
} }
func (table *Table) GetColumnIdx(name string, idx int) *Column { func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name) cols := table.columnsByName(name)
if cols != nil && idx < len(cols) { if cols != nil && idx < len(cols) {

View File

@ -74,6 +74,7 @@ var (
NVarchar = "NVARCHAR" NVarchar = "NVARCHAR"
TinyText = "TINYTEXT" TinyText = "TINYTEXT"
Text = "TEXT" Text = "TEXT"
NText = "NTEXT"
Clob = "CLOB" Clob = "CLOB"
MediumText = "MEDIUMTEXT" MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT" LongText = "LONGTEXT"
@ -130,6 +131,7 @@ var (
NVarchar: TEXT_TYPE, NVarchar: TEXT_TYPE,
TinyText: TEXT_TYPE, TinyText: TEXT_TYPE,
Text: TEXT_TYPE, Text: TEXT_TYPE,
NText: TEXT_TYPE,
MediumText: TEXT_TYPE, MediumText: TEXT_TYPE,
LongText: TEXT_TYPE, LongText: TEXT_TYPE,
Uuid: TEXT_TYPE, Uuid: TEXT_TYPE,
@ -293,7 +295,7 @@ func SQLType2Type(st SQLType) reflect.Type {
return reflect.TypeOf(float32(1)) return reflect.TypeOf(float32(1))
case Double: case Double:
return reflect.TypeOf(float64(1)) return reflect.TypeOf(float64(1))
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: case Char, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
return reflect.TypeOf("") return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier:
return reflect.TypeOf([]byte{}) return reflect.TypeOf([]byte{})

View File

@ -28,3 +28,6 @@ temp_test.go
.vscode .vscode
xorm.test xorm.test
*.sqlite3 *.sqlite3
test.db.sql
.idea/

View File

@ -32,13 +32,10 @@ proposed functionality.
We appreciate any bug reports, but especially ones with self-contained We appreciate any bug reports, but especially ones with self-contained
(doesn't depend on code outside of xorm), minimal (can't be simplified (doesn't depend on code outside of xorm), minimal (can't be simplified
further) test cases. It's especially helpful if you can submit a pull further) test cases. It's especially helpful if you can submit a pull
request with just the failing test case (you'll probably want to request with just the failing test case(you can find some example test file like [session_get_test.go](https://github.com/go-xorm/xorm/blob/master/session_get_test.go)).
pattern it after the tests in
[base.go](https://github.com/go-xorm/tests/blob/master/base.go) AND
[benchmark.go](https://github.com/go-xorm/tests/blob/master/benchmark.go).
If you implements a new database interface, you maybe need to add a <databasename>_test.go file. If you implements a new database interface, you maybe need to add a test_<databasename>.sh file.
For example, [mysql_test.go](https://github.com/go-xorm/tests/blob/master/mysql/mysql_test.go) For example, [mysql_test.go](https://github.com/go-xorm/xorm/blob/master/test_mysql.sh)
### New functionality ### New functionality

View File

@ -1,3 +1,5 @@
# xorm
[中文](https://github.com/go-xorm/xorm/blob/master/README_CN.md) [中文](https://github.com/go-xorm/xorm/blob/master/README_CN.md)
Xorm is a simple and powerful ORM for Go. Xorm is a simple and powerful ORM for Go.
@ -6,7 +8,7 @@ Xorm is a simple and powerful ORM for Go.
[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm) [![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm)
[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) [![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3)
# Features ## Features
* Struct <-> Table Mapping Support * Struct <-> Table Mapping Support
@ -28,7 +30,13 @@ Xorm is a simple and powerful ORM for Go.
* SQL Builder support via [github.com/go-xorm/builder](https://github.com/go-xorm/builder) * SQL Builder support via [github.com/go-xorm/builder](https://github.com/go-xorm/builder)
# Drivers Support * Automatical Read/Write seperatelly
* Postgres schema support
* Context Cache support
## Drivers Support
Drivers for Go's sql package which currently support database/sql includes: Drivers for Go's sql package which currently support database/sql includes:
@ -46,43 +54,17 @@ Drivers for Go's sql package which currently support database/sql includes:
* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment) * Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment)
# Changelog ## Installation
* **v0.6.3**
* merge tests to main project
* add `Exist` function
* add `SumInt` function
* Mysql now support read and create column comment.
* fix time related bugs.
* fix some other bugs.
* **v0.6.2**
* refactor tag parse methods
* add Scan features to Get
* add QueryString method
* **v0.6.0**
* remove support for ql
* add query condition builder support via [github.com/go-xorm/builder](https://github.com/go-xorm/builder), so `Where`, `And`, `Or`
methods can use `builder.Cond` as parameter
* add Sum, SumInt, SumInt64 and NotIn methods
* some bugs fixed
[More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16)
# Installation
go get github.com/go-xorm/xorm go get github.com/go-xorm/xorm
# Documents ## Documents
* [Manual](http://xorm.io/docs) * [Manual](http://xorm.io/docs)
* [GoDoc](http://godoc.org/github.com/go-xorm/xorm) * [GoDoc](http://godoc.org/github.com/go-xorm/xorm)
* [GoWalker](http://gowalker.org/github.com/go-xorm/xorm) ## Quick Start
# Quick Start
* Create Engine * Create Engine
@ -106,15 +88,36 @@ type User struct {
err := engine.Sync2(new(User)) err := engine.Sync2(new(User))
``` ```
* `Query` runs a SQL string, the returned results is `[]map[string][]byte`, `QueryString` returns `[]map[string]string`. * Create Engine Group
```Go
dataSourceNameSlice := []string{masterDataSourceName, slave1DataSourceName, slave2DataSourceName}
engineGroup, err := xorm.NewEngineGroup(driverName, dataSourceNameSlice)
```
```Go
masterEngine, err := xorm.NewEngine(driverName, masterDataSourceName)
slave1Engine, err := xorm.NewEngine(driverName, slave1DataSourceName)
slave2Engine, err := xorm.NewEngine(driverName, slave2DataSourceName)
engineGroup, err := xorm.NewEngineGroup(masterEngine, []*Engine{slave1Engine, slave2Engine})
```
Then all place where `engine` you can just use `engineGroup`.
* `Query` runs a SQL string, the returned results is `[]map[string][]byte`, `QueryString` returns `[]map[string]string`, `QueryInterface` returns `[]map[string]interface{}`.
```Go ```Go
results, err := engine.Query("select * from user") results, err := engine.Query("select * from user")
results, err := engine.Where("a = 1").Query()
results, err := engine.QueryString("select * from user") results, err := engine.QueryString("select * from user")
results, err := engine.Where("a = 1").QueryString()
results, err := engine.QueryInterface("select * from user")
results, err := engine.Where("a = 1").QueryInterface()
``` ```
* `Execute` runs a SQL string, it returns `affected` and `error` * `Exec` runs a SQL string, it returns `affected` and `error`
```Go ```Go
affected, err := engine.Exec("update user set age = ? where name = ?", age, name) affected, err := engine.Exec("update user set age = ? where name = ?", age, name)
@ -125,62 +128,76 @@ affected, err := engine.Exec("update user set age = ? where name = ?", age, name
```Go ```Go
affected, err := engine.Insert(&user) affected, err := engine.Insert(&user)
// INSERT INTO struct () values () // INSERT INTO struct () values ()
affected, err := engine.Insert(&user1, &user2) affected, err := engine.Insert(&user1, &user2)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values () // INSERT INTO struct2 () values ()
affected, err := engine.Insert(&users) affected, err := engine.Insert(&users)
// INSERT INTO struct () values (),(),() // INSERT INTO struct () values (),(),()
affected, err := engine.Insert(&user1, &users) affected, err := engine.Insert(&user1, &users)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values (),(),() // INSERT INTO struct2 () values (),(),()
``` ```
* Query one record from database * `Get` query one record from database
```Go ```Go
has, err := engine.Get(&user) has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1 // SELECT * FROM user LIMIT 1
has, err := engine.Where("name = ?", name).Desc("id").Get(&user) has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name) has, err := engine.Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ? // SELECT name FROM user WHERE id = ?
var id int64 var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id) has, err := engine.Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ? // SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string) var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap) has, err := engine.Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ? // SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols)) var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ? // SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
* Check if one record exist on table * `Exist` check if one record exist on table
```Go ```Go
has, err := testEngine.Exist(new(RecordExist)) has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1 // SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{ has, err = testEngine.Exist(&RecordExist{
Name: "test1", Name: "test1",
}) })
// SELECT * FROM record_exist WHERE name = ? LIMIT 1 // SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1 // SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ? // select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist() has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1 // SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1 // SELECT * FROM record_exist WHERE name = ? LIMIT 1
``` ```
* Query multiple records from database, also you can use join and extends * `Find` query multiple records from database, also you can use join and extends
```Go ```Go
var users []User var users []User
err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users) err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users)
// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10 // SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0
type Detail struct { type Detail struct {
Id int64 Id int64
@ -193,14 +210,14 @@ type UserDetail struct {
} }
var users []UserDetail var users []UserDetail
err := engine.Table("user").Select("user.*, detail.*") err := engine.Table("user").Select("user.*, detail.*").
Join("INNER", "detail", "detail.user_id = user.id"). Join("INNER", "detail", "detail.user_id = user.id").
Where("user.name = ?", name).Limit(10, 0). Where("user.name = ?", name).Limit(10, 0).
Find(&users) Find(&users)
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10 // SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0
``` ```
* Query multiple records and record by record handle, there are two methods Iterate and Rows * `Iterate` and `Rows` query multiple records and record by record handle, there are two methods Iterate and Rows
```Go ```Go
err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error { err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
@ -209,6 +226,13 @@ err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
}) })
// SELECT * FROM user // SELECT * FROM user
err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
user := bean.(*User)
return nil
})
// SELECT * FROM user Limit 0, 100
// SELECT * FROM user Limit 101, 100
rows, err := engine.Rows(&User{Name:name}) rows, err := engine.Rows(&User{Name:name})
// SELECT * FROM user // SELECT * FROM user
defer rows.Close() defer rows.Close()
@ -218,10 +242,10 @@ for rows.Next() {
} }
``` ```
* Update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on. * `Update` update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on.
```Go ```Go
affected, err := engine.Id(1).Update(&user) affected, err := engine.ID(1).Update(&user)
// UPDATE user SET ... Where id = ? // UPDATE user SET ... Where id = ?
affected, err := engine.Update(&user, &User{Name:name}) affected, err := engine.Update(&user, &User{Name:name})
@ -232,32 +256,50 @@ affected, err := engine.In("id", ids).Update(&user)
// UPDATE user SET ... Where id IN (?, ?, ?) // UPDATE user SET ... Where id IN (?, ?, ?)
// force update indicated columns by Cols // force update indicated columns by Cols
affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12}) affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ? // UPDATE user SET age = ?, updated=? Where id = ?
// force NOT update indicated columns by Omit // force NOT update indicated columns by Omit
affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12}) affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ? // UPDATE user SET age = ?, updated=? Where id = ?
affected, err := engine.Id(1).AllCols().Update(&user) affected, err := engine.ID(1).AllCols().Update(&user)
// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? // UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ?
``` ```
* Delete one or more records, Delete MUST have condition * `Delete` delete one or more records, Delete MUST have condition
```Go ```Go
affected, err := engine.Where(...).Delete(&user) affected, err := engine.Where(...).Delete(&user)
// DELETE FROM user Where ... // DELETE FROM user Where ...
affected, err := engine.Id(2).Delete(&user)
affected, err := engine.ID(2).Delete(&user)
// DELETE FROM user Where id = ?
``` ```
* Count records * `Count` count records
```Go ```Go
counts, err := engine.Count(&user) counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user // SELECT count(*) AS total FROM user
``` ```
* `Sum` sum functions
```Go
agesFloat64, err := engine.Sum(&user, "age")
// SELECT sum(age) AS total FROM user
agesInt64, err := engine.SumInt(&user, "age")
// SELECT sum(age) AS total FROM user
sumFloat64Slice, err := engine.Sums(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
sumInt64Slice, err := engine.SumsInt(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
```
* Query conditions builder * Query conditions builder
```Go ```Go
@ -265,7 +307,155 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
// SELECT id, name ... FROM user WHERE a NOT IN (?, ?) AND b IN (?, ?, ?) // SELECT id, name ... FROM user WHERE a NOT IN (?, ?) AND b IN (?, ?, ?)
``` ```
# Cases * Multiple operations in one go routine, no transation here but resue session memory
```Go
session := engine.NewSession()
defer session.Close()
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
return nil
```
* Transation should on one go routine. There is transaction and resue session memory
```Go
session := engine.NewSession()
defer session.Close()
// add Begin() before any action
if err := session.Begin(); err != nil {
// if returned then will rollback automatically
return err
}
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
// add Commit() after all actions
return session.Commit()
```
* Or you can use `Transaction` to replace above codes.
```Go
res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) {
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return nil, err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return nil, err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return nil, err
}
return nil, nil
})
```
* Context Cache, if enabled, current query result will be cached on session and be used by next same statement on the same session.
```Go
sess := engine.NewSession()
defer sess.Close()
var context = xorm.NewMemoryContextCache()
var c2 ContextGetStruct
has, err := sess.ID(1).ContextCache(context).Get(&c2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c2.Id)
assert.EqualValues(t, "1", c2.Name)
sql, args := sess.LastSQL()
assert.True(t, len(sql) > 0)
assert.True(t, len(args) > 0)
var c3 ContextGetStruct
has, err = sess.ID(1).ContextCache(context).Get(&c3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c3.Id)
assert.EqualValues(t, "1", c3.Name)
sql, args = sess.LastSQL()
assert.True(t, len(sql) == 0)
assert.True(t, len(args) == 0)
```
## Contributing
If you want to pull request, please see [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md). And we also provide [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) to discuss.
## Credits
### Contributors
This project exists thanks to all the people who contribute. [[Contribute](CONTRIBUTING.md)].
<a href="graphs/contributors"><img src="https://opencollective.com/xorm/contributors.svg?width=890&button=false" /></a>
### Backers
Thank you to all our backers! 🙏 [[Become a backer](https://opencollective.com/xorm#backer)]
<a href="https://opencollective.com/xorm#backers" target="_blank"><img src="https://opencollective.com/xorm/backers.svg?width=890"></a>
### Sponsors
Support this project by becoming a sponsor. Your logo will show up here with a link to your website. [[Become a sponsor](https://opencollective.com/xorm#sponsor)]
## Changelog
* **v0.7.0**
* Some bugs fixed
* **v0.6.6**
* Some bugs fixed
* **v0.6.5**
* Postgres schema support
* vgo support
* Add FindAndCount
* Database special params support via NewEngineWithParams
* Some bugs fixed
* **v0.6.4**
* Automatical Read/Write seperatelly
* Query/QueryString/QueryInterface and action with Where/And
* Get support non-struct variables
* BufferSize on Iterate
* fix some other bugs.
[More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16)
## Cases
* [studygolang](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang) * [studygolang](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
@ -301,15 +491,6 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [go-blog](http://wangcheng.me) - [github.com/easykoo/go-blog](https://github.com/easykoo/go-blog) * [go-blog](http://wangcheng.me) - [github.com/easykoo/go-blog](https://github.com/easykoo/go-blog)
# Discuss ## LICENSE
Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) BSD License [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)
# Contributing
If you want to pull request, please see [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md)
# LICENSE
BSD License
[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)

View File

@ -22,6 +22,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* 支持级联加载Struct * 支持级联加载Struct
* Schema支持仅Postgres
* 支持缓存 * 支持缓存
* 支持根据数据库自动生成xorm的结构体 * 支持根据数据库自动生成xorm的结构体
@ -30,6 +32,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* 内置SQL Builder支持 * 内置SQL Builder支持
* 上下文缓存支持
## 驱动支持 ## 驱动支持
目前支持的Go数据库驱动和对应的数据库如下 目前支持的Go数据库驱动和对应的数据库如下
@ -50,35 +54,6 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持) * Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持)
## 更新日志
* **v0.6.3**
* 合并单元测试到主工程
* 新增`Exist`方法
* 新增`SumInt`方法
* Mysql新增读取和创建字段注释支持
* 新增`SetConnMaxLifetime`方法
* 修正了时间相关的Bug
* 修复了一些其它Bug
* **v0.6.2**
* 重构Tag解析方式
* Get方法新增类似Scan的特性
* 新增 QueryString 方法
* **v0.6.0**
* 去除对 ql 的支持
* 新增条件查询分析器 [github.com/go-xorm/builder](https://github.com/go-xorm/builder), 从因此 `Where, And, Or` 函数
将可以用 `builder.Cond` 作为条件组合
* 新增 Sum, SumInt, SumInt64 和 NotIn 函数
* Bug修正
* **v0.5.0**
* logging接口进行不兼容改变
* Bug修正
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
## 安装 ## 安装
go get github.com/go-xorm/xorm go get github.com/go-xorm/xorm
@ -115,12 +90,33 @@ type User struct {
err := engine.Sync2(new(User)) err := engine.Sync2(new(User))
``` ```
* `Query` 最原始的也支持SQL语句查询返回的结果类型为 []map[string][]byte。`QueryString` 返回 []map[string]string * 创建Engine组
```Go
dataSourceNameSlice := []string{masterDataSourceName, slave1DataSourceName, slave2DataSourceName}
engineGroup, err := xorm.NewEngineGroup(driverName, dataSourceNameSlice)
```
```Go
masterEngine, err := xorm.NewEngine(driverName, masterDataSourceName)
slave1Engine, err := xorm.NewEngine(driverName, slave1DataSourceName)
slave2Engine, err := xorm.NewEngine(driverName, slave2DataSourceName)
engineGroup, err := xorm.NewEngineGroup(masterEngine, []*Engine{slave1Engine, slave2Engine})
```
所有使用 `engine` 都可以简单的用 `engineGroup` 来替换。
* `Query` 最原始的也支持SQL语句查询返回的结果类型为 []map[string][]byte。`QueryString` 返回 []map[string]string, `QueryInterface` 返回 `[]map[string]interface{}`.
```Go ```Go
results, err := engine.Query("select * from user") results, err := engine.Query("select * from user")
results, err := engine.Where("a = 1").Query()
results, err := engine.QueryString("select * from user") results, err := engine.QueryString("select * from user")
results, err := engine.Where("a = 1").QueryString()
results, err := engine.QueryInterface("select * from user")
results, err := engine.Where("a = 1").QueryInterface()
``` ```
* `Exec` 执行一个SQL语句 * `Exec` 执行一个SQL语句
@ -129,67 +125,81 @@ results, err := engine.QueryString("select * from user")
affected, err := engine.Exec("update user set age = ? where name = ?", age, name) affected, err := engine.Exec("update user set age = ? where name = ?", age, name)
``` ```
* 插入一条或者多条记录 * `Insert` 插入一条或者多条记录
```Go ```Go
affected, err := engine.Insert(&user) affected, err := engine.Insert(&user)
// INSERT INTO struct () values () // INSERT INTO struct () values ()
affected, err := engine.Insert(&user1, &user2) affected, err := engine.Insert(&user1, &user2)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values () // INSERT INTO struct2 () values ()
affected, err := engine.Insert(&users) affected, err := engine.Insert(&users)
// INSERT INTO struct () values (),(),() // INSERT INTO struct () values (),(),()
affected, err := engine.Insert(&user1, &users) affected, err := engine.Insert(&user1, &users)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values (),(),() // INSERT INTO struct2 () values (),(),()
``` ```
* 查询单条记录 * `Get` 查询单条记录
```Go ```Go
has, err := engine.Get(&user) has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1 // SELECT * FROM user LIMIT 1
has, err := engine.Where("name = ?", name).Desc("id").Get(&user) has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name) has, err := engine.Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ? // SELECT name FROM user WHERE id = ?
var id int64 var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id) has, err := engine.Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ? // SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string) var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap) has, err := engine.Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ? // SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols)) var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ? // SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
* 检测记录是否存在 * `Exist` 检测记录是否存在
```Go ```Go
has, err := testEngine.Exist(new(RecordExist)) has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1 // SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{ has, err = testEngine.Exist(&RecordExist{
Name: "test1", Name: "test1",
}) })
// SELECT * FROM record_exist WHERE name = ? LIMIT 1 // SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1 // SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ? // select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist() has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1 // SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1 // SELECT * FROM record_exist WHERE name = ? LIMIT 1
``` ```
* 查询多条记录当然可以使用Join和extends来组合使用 * `Find` 查询多条记录当然可以使用Join和extends来组合使用
```Go ```Go
var users []User var users []User
err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users) err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users)
// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10 // SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0
type Detail struct { type Detail struct {
Id int64 Id int64
@ -206,10 +216,10 @@ err := engine.Table("user").Select("user.*, detail.*")
Join("INNER", "detail", "detail.user_id = user.id"). Join("INNER", "detail", "detail.user_id = user.id").
Where("user.name = ?", name).Limit(10, 0). Where("user.name = ?", name).Limit(10, 0).
Find(&users) Find(&users)
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10 // SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0
``` ```
* 根据条件遍历数据库,可以有两种方式: Iterate and Rows * `Iterate``Rows` 根据条件遍历数据库,可以有两种方式: Iterate and Rows
```Go ```Go
err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error { err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
@ -218,6 +228,13 @@ err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
}) })
// SELECT * FROM user // SELECT * FROM user
err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
user := bean.(*User)
return nil
})
// SELECT * FROM user Limit 0, 100
// SELECT * FROM user Limit 101, 100
rows, err := engine.Rows(&User{Name:name}) rows, err := engine.Rows(&User{Name:name})
// SELECT * FROM user // SELECT * FROM user
defer rows.Close() defer rows.Close()
@ -227,10 +244,10 @@ for rows.Next() {
} }
``` ```
* 更新数据除非使用Cols,AllCols函数指明默认只更新非空和非0的字段 * `Update` 更新数据除非使用Cols,AllCols函数指明默认只更新非空和非0的字段
```Go ```Go
affected, err := engine.Id(1).Update(&user) affected, err := engine.ID(1).Update(&user)
// UPDATE user SET ... Where id = ? // UPDATE user SET ... Where id = ?
affected, err := engine.Update(&user, &User{Name:name}) affected, err := engine.Update(&user, &User{Name:name})
@ -241,31 +258,50 @@ affected, err := engine.In(ids).Update(&user)
// UPDATE user SET ... Where id IN (?, ?, ?) // UPDATE user SET ... Where id IN (?, ?, ?)
// force update indicated columns by Cols // force update indicated columns by Cols
affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12}) affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ? // UPDATE user SET age = ?, updated=? Where id = ?
// force NOT update indicated columns by Omit // force NOT update indicated columns by Omit
affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12}) affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ? // UPDATE user SET age = ?, updated=? Where id = ?
affected, err := engine.Id(1).AllCols().Update(&user) affected, err := engine.ID(1).AllCols().Update(&user)
// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? // UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ?
``` ```
* 删除记录需要注意删除必须至少有一个条件否则会报错。要清空数据库可以用EmptyTable * `Delete` 删除记录需要注意删除必须至少有一个条件否则会报错。要清空数据库可以用EmptyTable
```Go ```Go
affected, err := engine.Where(...).Delete(&user) affected, err := engine.Where(...).Delete(&user)
// DELETE FROM user Where ... // DELETE FROM user Where ...
affected, err := engine.ID(2).Delete(&user)
// DELETE FROM user Where id = ?
``` ```
* 获取记录条数 * `Count` 获取记录条数
```Go ```Go
counts, err := engine.Count(&user) counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user // SELECT count(*) AS total FROM user
``` ```
* `Sum` 求和函数
```Go
agesFloat64, err := engine.Sum(&user, "age")
// SELECT sum(age) AS total FROM user
agesInt64, err := engine.SumInt(&user, "age")
// SELECT sum(age) AS total FROM user
sumFloat64Slice, err := engine.Sums(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
sumInt64Slice, err := engine.SumsInt(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
```
* 条件编辑器 * 条件编辑器
```Go ```Go
@ -273,6 +309,132 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
// SELECT id, name ... FROM user WHERE a NOT IN (?, ?) AND b IN (?, ?, ?) // SELECT id, name ... FROM user WHERE a NOT IN (?, ?) AND b IN (?, ?, ?)
``` ```
* 在一个Go程中多次操作数据库但没有事务
```Go
session := engine.NewSession()
defer session.Close()
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
return nil
```
* 在一个Go程中有事务
```Go
session := engine.NewSession()
defer session.Close()
// add Begin() before any action
if err := session.Begin(); err != nil {
// if returned then will rollback automatically
return err
}
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
// add Commit() after all actions
return session.Commit()
```
* 事物的简写方法
```Go
res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) {
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return nil, err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return nil, err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return nil, err
}
return nil, nil
})
```
* Context Cache, if enabled, current query result will be cached on session and be used by next same statement on the same session.
```Go
sess := engine.NewSession()
defer sess.Close()
var context = xorm.NewMemoryContextCache()
var c2 ContextGetStruct
has, err := sess.ID(1).ContextCache(context).Get(&c2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c2.Id)
assert.EqualValues(t, "1", c2.Name)
sql, args := sess.LastSQL()
assert.True(t, len(sql) > 0)
assert.True(t, len(args) > 0)
var c3 ContextGetStruct
has, err = sess.ID(1).ContextCache(context).Get(&c3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c3.Id)
assert.EqualValues(t, "1", c3.Name)
sql, args = sess.LastSQL()
assert.True(t, len(sql) == 0)
assert.True(t, len(args) == 0)
```
## 贡献
如果您也想为Xorm贡献您的力量请查看 [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md)。您也可以加入QQ群 技术帮助和讨论。
群一280360085 (已满)
群二795010183
## Credits
### Contributors
感谢所有的贡献者. [[Contribute](CONTRIBUTING.md)].
<a href="graphs/contributors"><img src="https://opencollective.com/xorm/contributors.svg?width=890&button=false" /></a>
### Backers
感谢我们所有的 backers! 🙏 [[成为 backer](https://opencollective.com/xorm#backer)]
<a href="https://opencollective.com/xorm#backers" target="_blank"><img src="https://opencollective.com/xorm/backers.svg?width=890"></a>
### Sponsors
成为 sponsor 来支持 xorm。您的 logo 将会被显示并被链接到您的网站。 [[成为 sponsor](https://opencollective.com/xorm#sponsor)]
# 案例 # 案例
* [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang) * [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
@ -307,13 +469,30 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [go-blog](http://wangcheng.me) - [github.com/easykoo/go-blog](https://github.com/easykoo/go-blog) * [go-blog](http://wangcheng.me) - [github.com/easykoo/go-blog](https://github.com/easykoo/go-blog)
## 讨论
请加入QQ群280360085 进行讨论。 ## 更新日志
## 贡献 * **v0.7.0**
* 修正部分Bug
如果您也想为Xorm贡献您的力量请查看 [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md) * **v0.6.6**
* 修正部分Bug
* **v0.6.5**
* 通过 engine.SetSchema 来支持 schema当前仅支持Postgres
* vgo 支持
* 新增 `FindAndCount` 函数
* 通过 `NewEngineWithParams` 支持数据库特别参数
* 修正部分Bug
* **v0.6.4**
* 自动读写分离支持
* Query/QueryString/QueryInterface 支持与 Where/And 合用
* `Get` 支持获取非结构体变量
* `Iterate` 支持 `BufferSize`
* 修正部分Bug
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
## LICENSE ## LICENSE

View File

@ -17,11 +17,12 @@ database:
- createdb -p 5432 -e -U postgres xorm_test1 - createdb -p 5432 -e -U postgres xorm_test1
- createdb -p 5432 -e -U postgres xorm_test2 - createdb -p 5432 -e -U postgres xorm_test2
- createdb -p 5432 -e -U postgres xorm_test3 - createdb -p 5432 -e -U postgres xorm_test3
- psql xorm_test postgres -c "create schema xorm"
test: test:
override: override:
# './...' is a relative pattern which means all subdirectories # './...' is a relative pattern which means all subdirectories
- go get -u github.com/wadey/gocovmerge; - go get -u github.com/wadey/gocovmerge
- go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic - go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic
- go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic - go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic
- go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic - go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic
@ -30,7 +31,9 @@ test:
- go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic - go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt > coverage.txt - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh

View File

@ -209,10 +209,10 @@ func convertAssign(dest, src interface{}) error {
if src == nil { if src == nil {
dv.Set(reflect.Zero(dv.Type())) dv.Set(reflect.Zero(dv.Type()))
return nil return nil
} else {
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
} }
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src) s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())

View File

@ -172,12 +172,33 @@ type mysql struct {
allowAllFiles bool allowAllFiles bool
allowOldPasswords bool allowOldPasswords bool
clientFoundRows bool clientFoundRows bool
rowFormat string
} }
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri, drivername, dataSourceName)
} }
func (db *mysql) SetParams(params map[string]string) {
rowFormat, ok := params["rowFormat"]
if ok {
var t = strings.ToUpper(rowFormat)
switch t {
case "COMPACT":
fallthrough
case "REDUNDANT":
fallthrough
case "DYNAMIC":
fallthrough
case "COMPRESSED":
db.rowFormat = t
break
default:
break
}
}
}
func (db *mysql) SqlType(c *core.Column) string { func (db *mysql) SqlType(c *core.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
@ -487,6 +508,62 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil return indexes, nil
} }
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
sql += db.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db)
} else {
sql += col.StringNoPk(db)
}
sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
if storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if len(charset) == 0 {
charset = db.URI().Charset
}
if len(charset) != 0 {
sql += " DEFAULT CHARSET " + charset
}
if db.rowFormat != "" {
sql += " ROW_FORMAT=" + db.rowFormat
}
return sql
}
func (db *mysql) Filters() []core.Filter { func (db *mysql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}} return []core.Filter{&core.IdFilter{}}
} }

View File

@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"sort"
"strconv" "strconv"
"strings" "strings"
@ -765,14 +764,26 @@ var (
"YES": true, "YES": true,
"ZONE": true, "ZONE": true,
} }
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
) )
const postgresPublicSchema = "public"
type postgres struct { type postgres struct {
core.Base core.Base
} }
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) err := db.Base.Init(d, db, uri, drivername, dataSourceName)
if err != nil {
return err
}
if db.Schema == "" {
db.Schema = DefaultPostgresSchema
}
return nil
} }
func (db *postgres) SqlType(c *core.Column) string { func (db *postgres) SqlType(c *core.Column) string {
@ -869,32 +880,42 @@ func (db *postgres) IndexOnTable() bool {
} }
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName} if len(db.Schema) == 0 {
args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
}
args := []interface{}{db.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} if len(db.Schema) == 0 {
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
}
args := []interface{}{db.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
} }
/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args
}*/
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", if len(db.Schema) == 0 {
tableName, col.Name, db.SqlType(col)) return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
}
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.Schema, tableName, col.Name, db.SqlType(col))
} }
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
//var unique string
quote := db.Quote quote := db.Quote
idxName := index.Name idxName := index.Name
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if !strings.HasPrefix(idxName, "UQE_") && if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") { !strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
@ -903,13 +924,21 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
} }
} }
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) return fmt.Sprintf("DROP INDEX %v", quote(idxName))
} }
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
args := []interface{}{tableName, colName} args := []interface{}{db.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $2" " AND column_name = $3"
if len(db.Schema) == 0 {
args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
}
db.LogSQL(query, args) db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...) rows, err := db.DB().Query(query, args...)
@ -922,8 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
// FIXME: the schema should be replaced by user custom's args := []interface{}{tableName}
args := []interface{}{tableName, "public"}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -934,7 +962,15 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
var f string
if len(db.Schema) != 0 {
args = append(args, db.Schema)
f = " AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)
db.LogSQL(s, args) db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
@ -1024,9 +1060,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att
} }
func (db *postgres) GetTables() ([]*core.Table, error) { func (db *postgres) GetTables() ([]*core.Table, error) {
// FIXME: replace public to user customrize schema args := []interface{}{}
args := []interface{}{"public"} s := "SELECT tablename FROM pg_tables"
s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1") if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " WHERE schemaname = $1"
}
db.LogSQL(s, args) db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
@ -1050,9 +1090,12 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
// FIXME: replace the public schema to user specify schema args := []interface{}{tableName}
args := []interface{}{"public", tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2") if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " AND schemaname=$2"
}
db.LogSQL(s, args) db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
@ -1117,10 +1160,6 @@ func (vs values) Get(k string) (v string) {
return vs[k] return vs[k]
} }
func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
func parseURL(connstr string) (string, error) { func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr) u, err := url.Parse(connstr)
if err != nil { if err != nil {
@ -1131,46 +1170,18 @@ func parseURL(connstr string) (string, error) {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
} }
var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" { if u.Path != "" {
accrue("dbname", u.Path[1:]) return escaper.Replace(u.Path[1:]), nil
} }
q := u.Query() return "", nil
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
} }
func parseOpts(name string, o values) { func parseOpts(name string, o values) error {
if len(name) == 0 { if len(name) == 0 {
return return fmt.Errorf("invalid options: %s", name)
} }
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
@ -1179,31 +1190,48 @@ func parseOpts(name string, o values) {
for _, p := range ps { for _, p := range ps {
kv := strings.Split(p, "=") kv := strings.Split(p, "=")
if len(kv) < 2 { if len(kv) < 2 {
errorf("invalid option: %q", p) return fmt.Errorf("invalid option: %q", p)
} }
o.Set(kv[0], kv[1]) o.Set(kv[0], kv[1])
} }
return nil
} }
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES} db := &core.Uri{DbType: core.POSTGRES}
o := make(values)
var err error var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
dataSourceName, err = parseURL(dataSourceName) db.DbName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
} else {
o := make(values)
err = parseOpts(dataSourceName, o)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
parseOpts(dataSourceName, o)
db.DbName = o.Get("dbname") db.DbName = o.Get("dbname")
}
if db.DbName == "" { if db.DbName == "" {
return nil, errors.New("dbname is empty") return nil, errors.New("dbname is empty")
} }
/*db.Schema = o.Get("schema")
if len(db.Schema) == 0 {
db.Schema = "public"
}*/
return db, nil return db, nil
} }
type pqDriverPgx struct {
pqDriver
}
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) {
// Remove the leading characters for driver to work
if len(dataSourceName) >= 9 && dataSourceName[0] == 0 {
dataSourceName = dataSourceName[9:]
}
return pgx.pqDriver.Parse(driverName, dataSourceName)
}

View File

@ -233,7 +233,7 @@ func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
} }
func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string {
//var unique string // var unique string
quote := db.Quote quote := db.Quote
idxName := index.Name idxName := index.Name
@ -452,5 +452,9 @@ type sqlite3Driver struct {
} }
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
if strings.Contains(dataSourceName, "?") {
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]
}
return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil
} }

View File

@ -47,6 +47,52 @@ type Engine struct {
disableGlobalCache bool disableGlobalCache bool
tagHandlers map[string]tagHandler tagHandlers map[string]tagHandler
engineGroup *EngineGroup
cachers map[string]core.Cacher
cacherLock sync.RWMutex
}
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
engine.cacherLock.Lock()
engine.cachers[tableName] = cacher
engine.cacherLock.Unlock()
}
func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) {
engine.setCacher(tableName, cacher)
}
func (engine *Engine) getCacher(tableName string) core.Cacher {
var cacher core.Cacher
var ok bool
engine.cacherLock.RLock()
cacher, ok = engine.cachers[tableName]
engine.cacherLock.RUnlock()
if !ok && !engine.disableGlobalCache {
cacher = engine.Cacher
}
return cacher
}
func (engine *Engine) GetCacher(tableName string) core.Cacher {
return engine.getCacher(tableName)
}
// BufferSize sets buffer size for iterate
func (engine *Engine) BufferSize(size int) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.BufferSize(size)
}
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(colName string) builder.Cond {
if engine.dialect.DBType() == core.MSSQL {
return builder.IsNull{colName}
}
return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1})
} }
// ShowSQL show SQL statement or not on logger if log level is great than INFO // ShowSQL show SQL statement or not on logger if log level is great than INFO
@ -79,6 +125,11 @@ func (engine *Engine) SetLogger(logger core.ILogger) {
engine.dialect.SetLogger(logger) engine.dialect.SetLogger(logger)
} }
// SetLogLevel sets the logger level
func (engine *Engine) SetLogLevel(level core.LogLevel) {
engine.logger.SetLevel(level)
}
// SetDisableGlobalCache disable global cache or not // SetDisableGlobalCache disable global cache or not
func (engine *Engine) SetDisableGlobalCache(disable bool) { func (engine *Engine) SetDisableGlobalCache(disable bool) {
if engine.disableGlobalCache != disable { if engine.disableGlobalCache != disable {
@ -126,6 +177,14 @@ func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr() return engine.dialect.QuoteStr()
} }
func (engine *Engine) quoteColumns(columnStr string) string {
columns := strings.Split(columnStr, ",")
for i := 0; i < len(columns); i++ {
columns[i] = engine.Quote(strings.TrimSpace(columns[i]))
}
return strings.Join(columns, ",")
}
// Quote Use QuoteStr quote the string sql // Quote Use QuoteStr quote the string sql
func (engine *Engine) Quote(value string) string { func (engine *Engine) Quote(value string) string {
value = strings.TrimSpace(value) value = strings.TrimSpace(value)
@ -143,7 +202,7 @@ func (engine *Engine) Quote(value string) string {
} }
// QuoteTo quotes string and writes into the buffer // QuoteTo quotes string and writes into the buffer
func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
if buf == nil { if buf == nil {
return return
} }
@ -186,6 +245,11 @@ func (engine *Engine) AutoIncrStr() string {
return engine.dialect.AutoIncrStr() return engine.dialect.AutoIncrStr()
} }
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
engine.db.SetConnMaxLifetime(d)
}
// SetMaxOpenConns is only available for go 1.2+ // SetMaxOpenConns is only available for go 1.2+
func (engine *Engine) SetMaxOpenConns(conns int) { func (engine *Engine) SetMaxOpenConns(conns int) {
engine.db.SetMaxOpenConns(conns) engine.db.SetMaxOpenConns(conns)
@ -201,6 +265,11 @@ func (engine *Engine) SetDefaultCacher(cacher core.Cacher) {
engine.Cacher = cacher engine.Cacher = cacher
} }
// GetDefaultCacher returns the default cacher
func (engine *Engine) GetDefaultCacher() core.Cacher {
return engine.Cacher
}
// NoCache If you has set default cacher, and you want temporilly stop use cache, // NoCache If you has set default cacher, and you want temporilly stop use cache,
// you can use NoCache() // you can use NoCache()
func (engine *Engine) NoCache() *Session { func (engine *Engine) NoCache() *Session {
@ -218,13 +287,7 @@ func (engine *Engine) NoCascade() *Session {
// MapCacher Set a table use a special cacher // MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
v := rValue(bean) engine.setCacher(engine.TableName(bean, true), cacher)
tb, err := engine.autoMapType(v)
if err != nil {
return err
}
tb.Cacher = cacher
return nil return nil
} }
@ -509,33 +572,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return nil return nil
} }
func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
v := rValue(beanOrTableName)
if v.Type().Kind() == reflect.String {
return beanOrTableName.(string), nil
} else if v.Type().Kind() == reflect.Struct {
return engine.tbName(v), nil
}
return "", errors.New("bean should be a struct or struct's point")
}
func (engine *Engine) tbName(v reflect.Value) string {
if tb, ok := v.Interface().(TableName); ok {
return tb.TableName()
}
if v.Type().Kind() == reflect.Ptr {
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
return tb.TableName()
}
} else if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return tb.TableName()
}
}
return engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name())
}
// Cascade use cascade or not // Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession() session := engine.NewSession()
@ -736,6 +772,13 @@ func (engine *Engine) OrderBy(order string) *Session {
return session.OrderBy(order) return session.OrderBy(order)
} }
// Prepare enables prepare statement
func (engine *Engine) Prepare() *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.Prepare()
}
// Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
@ -757,7 +800,8 @@ func (engine *Engine) Having(conditions string) *Session {
return session.Having(conditions) return session.Having(conditions)
} }
func (engine *Engine) unMapType(t reflect.Type) { // UnMapType removes the datbase mapper of a type
func (engine *Engine) UnMapType(t reflect.Type) {
engine.mutex.Lock() engine.mutex.Lock()
defer engine.mutex.Unlock() defer engine.mutex.Unlock()
delete(engine.Tables, t) delete(engine.Tables, t)
@ -811,7 +855,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
} }
return &Table{tb, engine.tbName(v)} return &Table{tb, engine.TableName(bean)}
} }
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@ -826,15 +870,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i
} }
} }
func (engine *Engine) newTable() *core.Table {
table := core.NewEmptyTable()
if !engine.disableGlobalCache {
table.Cacher = engine.Cacher
}
return table
}
// TableName table name interface to define customerize table name // TableName table name interface to define customerize table name
type TableName interface { type TableName interface {
TableName() string TableName() string
@ -846,21 +881,9 @@ var (
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type() t := v.Type()
table := engine.newTable() table := core.NewEmptyTable()
if tb, ok := v.Interface().(TableName); ok {
table.Name = tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
table.Name = tb.TableName()
}
}
if table.Name == "" {
table.Name = engine.TableMapper.Obj2Table(t.Name())
}
}
table.Type = t table.Type = t
table.Name = engine.tbNameForMap(v)
var idFieldColName string var idFieldColName string
var hasCacheTag, hasNoCacheTag bool var hasCacheTag, hasNoCacheTag bool
@ -914,7 +937,7 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
} }
if pStart > -1 { if pStart > -1 {
if !strings.HasSuffix(k, ")") { if !strings.HasSuffix(k, ")") {
return nil, errors.New("cannot match ) charactor") return nil, fmt.Errorf("field %s tag %s cannot match ) charactor", col.FieldName, key)
} }
ctx.tagName = k[:pStart] ctx.tagName = k[:pStart]
@ -1014,15 +1037,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
if hasCacheTag { if hasCacheTag {
if engine.Cacher != nil { // !nash! use engine's cacher if provided if engine.Cacher != nil { // !nash! use engine's cacher if provided
engine.logger.Info("enable cache on table:", table.Name) engine.logger.Info("enable cache on table:", table.Name)
table.Cacher = engine.Cacher engine.setCacher(table.Name, engine.Cacher)
} else { } else {
engine.logger.Info("enable LRU cache on table:", table.Name) engine.logger.Info("enable LRU cache on table:", table.Name)
table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000))
} }
} }
if hasNoCacheTag { if hasNoCacheTag {
engine.logger.Info("no cache on table:", table.Name) engine.logger.Info("disable cache on table:", table.Name)
table.Cacher = nil engine.setCacher(table.Name, nil)
} }
return table, nil return table, nil
@ -1081,7 +1104,25 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
pk := make([]interface{}, len(table.PrimaryKeys)) pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
var err error var err error
pkField := v.FieldByName(col.FieldName)
fieldName := col.FieldName
for {
parts := strings.SplitN(fieldName, ".", 2)
if len(parts) == 1 {
break
}
v = v.FieldByName(parts[0])
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, ErrUnSupportedType
}
fieldName = parts[1]
}
pkField := v.FieldByName(fieldName)
switch pkField.Kind() { switch pkField.Kind() {
case reflect.String: case reflect.String:
pk[i], err = engine.idTypeAssertion(col, pkField.String()) pk[i], err = engine.idTypeAssertion(col, pkField.String())
@ -1127,26 +1168,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
return session.CreateUniques(bean) return session.CreateUniques(bean)
} }
func (engine *Engine) getCacher2(table *core.Table) core.Cacher {
return table.Cacher
}
// ClearCacheBean if enabled cache, clear the cache bean // ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
v := rValue(bean) tableName := engine.TableName(bean)
t := v.Type() cacher := engine.getCacher(tableName)
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
if cacher != nil { if cacher != nil {
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
cacher.DelBean(tableName, id) cacher.DelBean(tableName, id)
@ -1157,21 +1182,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
// ClearCache if enabled cache, clear some tables' cache // ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error { func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) tableName := engine.TableName(bean)
t := v.Type() cacher := engine.getCacher(tableName)
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
if cacher != nil { if cacher != nil {
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
cacher.ClearBeans(tableName) cacher.ClearBeans(tableName)
@ -1189,13 +1201,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := rValue(bean)
tableName := engine.tbName(v) tableNameNoSchema := engine.TableName(bean)
table, err := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil { if err != nil {
return err return err
} }
isExist, err := session.Table(bean).isTableExist(tableName) isExist, err := session.Table(bean).isTableExist(tableNameNoSchema)
if err != nil { if err != nil {
return err return err
} }
@ -1221,12 +1233,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
} else { } else {
for _, col := range table.Columns() { for _, col := range table.Columns() {
isExist, err := engine.dialect.IsColumnExist(tableName, col.Name) isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name)
if err != nil { if err != nil {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefValue(v); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return err return err
} }
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
@ -1237,35 +1249,35 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
for name, index := range table.Indexes { for name, index := range table.Indexes {
if err := session.statement.setRefValue(v); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return err return err
} }
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
isExist, err := session.isIndexExist2(tableName, index.Cols, true) isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
if err != nil { if err != nil {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefValue(v); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return err return err
} }
err = session.addUnique(tableName, name) err = session.addUnique(tableNameNoSchema, name)
if err != nil { if err != nil {
return err return err
} }
} }
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
isExist, err := session.isIndexExist2(tableName, index.Cols, false) isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
if err != nil { if err != nil {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefValue(v); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return err return err
} }
err = session.addIndex(tableName, name) err = session.addIndex(tableNameNoSchema, name)
if err != nil { if err != nil {
return err return err
} }
@ -1334,31 +1346,31 @@ func (engine *Engine) DropIndexes(bean interface{}) error {
} }
// Exec raw sql // Exec raw sql
func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) { func (engine *Engine) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.Exec(sql, args...) return session.Exec(sqlorArgs...)
} }
// Query a raw sql and return records as []map[string][]byte // Query a raw sql and return records as []map[string][]byte
func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (engine *Engine) Query(sqlorArgs ...interface{}) (resultsSlice []map[string][]byte, err error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.Query(sql, paramStr...) return session.Query(sqlorArgs...)
} }
// QueryString runs a raw sql and return records as []map[string]string // QueryString runs a raw sql and return records as []map[string]string
func (engine *Engine) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) { func (engine *Engine) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.QueryString(sqlStr, args...) return session.QueryString(sqlorArgs...)
} }
// QueryInterface runs a raw sql and return records as []map[string]interface{} // QueryInterface runs a raw sql and return records as []map[string]interface{}
func (engine *Engine) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) { func (engine *Engine) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.QueryInterface(sqlStr, args...) return session.QueryInterface(sqlorArgs...)
} }
// Insert one or more records // Insert one or more records
@ -1418,6 +1430,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
return session.Find(beans, condiBeans...) return session.Find(beans, condiBeans...)
} }
// FindAndCount find the results and also return the counts
func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.FindAndCount(rowsSlicePtr, condiBean...)
}
// Iterate record by record handle records from table, bean's non-empty fields // Iterate record by record handle records from table, bean's non-empty fields
// are conditions. // are conditions.
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error {
@ -1564,24 +1583,44 @@ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}
return return
} }
// GetColumnMapper returns the column name mapper
func (engine *Engine) GetColumnMapper() core.IMapper {
return engine.ColumnMapper
}
// GetTableMapper returns the table name mapper
func (engine *Engine) GetTableMapper() core.IMapper {
return engine.TableMapper
}
// GetTZLocation returns time zone of the application
func (engine *Engine) GetTZLocation() *time.Location {
return engine.TZLocation
}
// SetTZLocation sets time zone of the application
func (engine *Engine) SetTZLocation(tz *time.Location) {
engine.TZLocation = tz
}
// GetTZDatabase returns time zone of the database
func (engine *Engine) GetTZDatabase() *time.Location {
return engine.DatabaseTZ
}
// SetTZDatabase sets time zone of the database
func (engine *Engine) SetTZDatabase(tz *time.Location) {
engine.DatabaseTZ = tz
}
// SetSchema sets the schema of database
func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().Schema = schema
}
// Unscoped always disable struct tag "deleted" // Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session { func (engine *Engine) Unscoped() *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
return session.Unscoped() return session.Unscoped()
} }
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(colName string) builder.Cond {
if engine.dialect.DBType() == core.MSSQL {
return builder.IsNull{colName}
}
return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1})
}
// BufferSize sets buffer size for iterate
func (engine *Engine) BufferSize(size int) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.BufferSize(size)
}

View File

@ -9,6 +9,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
"github.com/go-xorm/builder" "github.com/go-xorm/builder"
@ -51,7 +52,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
fieldValuePtr, err := col.ValueOf(bean) fieldValuePtr, err := col.ValueOf(bean)
if err != nil { if err != nil {
engine.logger.Error(err) if !strings.Contains(err.Error(), "is not valid") {
engine.logger.Warn(err)
}
continue continue
} }

View File

@ -1,14 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.6
package xorm
import "time"
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
engine.db.SetConnMaxLifetime(d)
}

View File

@ -6,21 +6,44 @@ package xorm
import ( import (
"errors" "errors"
"fmt"
) )
var ( var (
// ErrParamsType params error // ErrParamsType params error
ErrParamsType = errors.New("Params type error") ErrParamsType = errors.New("Params type error")
// ErrTableNotFound table not found error // ErrTableNotFound table not found error
ErrTableNotFound = errors.New("Not found table") ErrTableNotFound = errors.New("Table not found")
// ErrUnSupportedType unsupported error // ErrUnSupportedType unsupported error
ErrUnSupportedType = errors.New("Unsupported type error") ErrUnSupportedType = errors.New("Unsupported type error")
// ErrNotExist record is not exist error // ErrNotExist record does not exist error
ErrNotExist = errors.New("Not exist error") ErrNotExist = errors.New("Record does not exist")
// ErrCacheFailed cache failed error // ErrCacheFailed cache failed error
ErrCacheFailed = errors.New("Cache failed") ErrCacheFailed = errors.New("Cache failed")
// ErrNeedDeletedCond delete needs less one condition error // ErrNeedDeletedCond delete needs less one condition error
ErrNeedDeletedCond = errors.New("Delete need at least one condition") ErrNeedDeletedCond = errors.New("Delete action needs at least one condition")
// ErrNotImplemented not implemented // ErrNotImplemented not implemented
ErrNotImplemented = errors.New("Not implemented") ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported condition type")
) )
// ErrFieldIsNotExist columns does not exist
type ErrFieldIsNotExist struct {
FieldName string
TableName string
}
func (e ErrFieldIsNotExist) Error() string {
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName)
}
// ErrFieldIsNotValid is not valid
type ErrFieldIsNotValid struct {
FieldName string
TableName string
}
func (e ErrFieldIsNotValid) Error() string {
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName)
}

View File

@ -11,7 +11,6 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
@ -293,19 +292,6 @@ func structName(v reflect.Type) string {
return v.Name() return v.Name()
} }
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
func sliceEq(left, right []string) bool { func sliceEq(left, right []string) bool {
if len(left) != len(right) { if len(left) != len(right) {
return false return false
@ -320,154 +306,6 @@ func sliceEq(left, right []string) bool {
return true return true
} }
func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t)
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t))
}
}
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
if col.IsDeleted {
continue
}
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
if includeQuote {
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}
}
return colNames, args, nil
}
func indexName(tableName, idxName string) string { func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName) return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
} }
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}

View File

@ -1,214 +0,0 @@
package migrate
import (
"errors"
"fmt"
"github.com/go-xorm/xorm"
)
// MigrateFunc is the func signature for migrating.
type MigrateFunc func(*xorm.Engine) error
// RollbackFunc is the func signature for rollbacking.
type RollbackFunc func(*xorm.Engine) error
// InitSchemaFunc is the func signature for initializing the schema.
type InitSchemaFunc func(*xorm.Engine) error
// Options define options for all migrations.
type Options struct {
// TableName is the migration table.
TableName string
// IDColumnName is the name of column where the migration id will be stored.
IDColumnName string
}
// Migration represents a database migration (a modification to be made on the database).
type Migration struct {
// ID is the migration identifier. Usually a timestamp like "201601021504".
ID string
// Migrate is a function that will br executed while running this migration.
Migrate MigrateFunc
// Rollback will be executed on rollback. Can be nil.
Rollback RollbackFunc
}
// Migrate represents a collection of all migrations of a database schema.
type Migrate struct {
db *xorm.Engine
options *Options
migrations []*Migration
initSchema InitSchemaFunc
}
var (
// DefaultOptions can be used if you don't want to think about options.
DefaultOptions = &Options{
TableName: "migrations",
IDColumnName: "id",
}
// ErrRollbackImpossible is returned when trying to rollback a migration
// that has no rollback function.
ErrRollbackImpossible = errors.New("It's impossible to rollback this migration")
// ErrNoMigrationDefined is returned when no migration is defined.
ErrNoMigrationDefined = errors.New("No migration defined")
// ErrMissingID is returned when the ID od migration is equal to ""
ErrMissingID = errors.New("Missing ID in migration")
// ErrNoRunnedMigration is returned when any runned migration was found while
// running RollbackLast
ErrNoRunnedMigration = errors.New("Could not find last runned migration")
)
// New returns a new Gormigrate.
func New(db *xorm.Engine, options *Options, migrations []*Migration) *Migrate {
return &Migrate{
db: db,
options: options,
migrations: migrations,
}
}
// InitSchema sets a function that is run if no migration is found.
// The idea is preventing to run all migrations when a new clean database
// is being migrating. In this function you should create all tables and
// foreign key necessary to your application.
func (m *Migrate) InitSchema(initSchema InitSchemaFunc) {
m.initSchema = initSchema
}
// Migrate executes all migrations that did not run yet.
func (m *Migrate) Migrate() error {
if err := m.createMigrationTableIfNotExists(); err != nil {
return err
}
if m.initSchema != nil && m.isFirstRun() {
if err := m.runInitSchema(); err != nil {
return err
}
return nil
}
for _, migration := range m.migrations {
if err := m.runMigration(migration); err != nil {
return err
}
}
return nil
}
// RollbackLast undo the last migration
func (m *Migrate) RollbackLast() error {
if len(m.migrations) == 0 {
return ErrNoMigrationDefined
}
lastRunnedMigration, err := m.getLastRunnedMigration()
if err != nil {
return err
}
if err := m.RollbackMigration(lastRunnedMigration); err != nil {
return err
}
return nil
}
func (m *Migrate) getLastRunnedMigration() (*Migration, error) {
for i := len(m.migrations) - 1; i >= 0; i-- {
migration := m.migrations[i]
if m.migrationDidRun(migration) {
return migration, nil
}
}
return nil, ErrNoRunnedMigration
}
// RollbackMigration undo a migration.
func (m *Migrate) RollbackMigration(mig *Migration) error {
if mig.Rollback == nil {
return ErrRollbackImpossible
}
if err := mig.Rollback(m.db); err != nil {
return err
}
sql := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName)
if _, err := m.db.Exec(sql, mig.ID); err != nil {
return err
}
return nil
}
func (m *Migrate) runInitSchema() error {
if err := m.initSchema(m.db); err != nil {
return err
}
for _, migration := range m.migrations {
if err := m.insertMigration(migration.ID); err != nil {
return err
}
}
return nil
}
func (m *Migrate) runMigration(migration *Migration) error {
if len(migration.ID) == 0 {
return ErrMissingID
}
if !m.migrationDidRun(migration) {
if err := migration.Migrate(m.db); err != nil {
return err
}
if err := m.insertMigration(migration.ID); err != nil {
return err
}
}
return nil
}
func (m *Migrate) createMigrationTableIfNotExists() error {
exists, err := m.db.IsTableExist(m.options.TableName)
if err != nil {
return err
}
if exists {
return nil
}
sql := fmt.Sprintf("CREATE TABLE %s (%s VARCHAR(255) PRIMARY KEY)", m.options.TableName, m.options.IDColumnName)
if _, err := m.db.Exec(sql); err != nil {
return err
}
return nil
}
func (m *Migrate) migrationDidRun(mig *Migration) bool {
row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID)
var count int
row.Scan(&count)
return count > 0
}
func (m *Migrate) isFirstRun() bool {
row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", m.options.TableName))
var count int
row.Scan(&count)
return count == 0
}
func (m *Migrate) insertMigration(id string) error {
sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?)", m.options.TableName, m.options.IDColumnName)
_, err := m.db.Exec(sql, id)
return err
}

View File

@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
var args []interface{} var args []interface{}
var err error var err error
if err = rows.session.statement.setRefValue(rValue(bean)); err != nil { if err = rows.session.statement.setRefBean(bean); err != nil {
return nil, err return nil, err
} }
@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
} }
dataStruct := rValue(bean) if err := rows.session.statement.setRefBean(bean); err != nil {
if err := rows.session.statement.setRefValue(dataStruct); err != nil {
return err return err
} }
@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return err return err
} }
dataStruct := rValue(bean)
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil { if err != nil {
return err return err

View File

@ -76,6 +76,7 @@ func (session *Session) Init() {
session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0)
session.beforeClosures = make([]func(interface{}), 0) session.beforeClosures = make([]func(interface{}), 0)
session.afterClosures = make([]func(interface{}), 0) session.afterClosures = make([]func(interface{}), 0)
session.stmtCache = make(map[uint32]*core.Stmt)
session.afterProcessors = make([]executedProcessor, 0) session.afterProcessors = make([]executedProcessor, 0)
@ -101,6 +102,12 @@ func (session *Session) Close() {
} }
} }
// ContextCache enable context cache or not
func (session *Session) ContextCache(context ContextCache) *Session {
session.statement.context = context
return session
}
// IsClosed returns if session is closed // IsClosed returns if session is closed
func (session *Session) IsClosed() bool { func (session *Session) IsClosed() bool {
return session.db == nil return session.db == nil
@ -262,13 +269,13 @@ func (session *Session) canCache() bool {
return true return true
} }
func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) { func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, err error) {
crc := crc32.ChecksumIEEE([]byte(sqlStr)) crc := crc32.ChecksumIEEE([]byte(sqlStr))
// TODO try hash(sqlStr+len(sqlStr)) // TODO try hash(sqlStr+len(sqlStr))
var has bool var has bool
stmt, has = session.stmtCache[crc] stmt, has = session.stmtCache[crc]
if !has { if !has {
stmt, err = session.DB().Prepare(sqlStr) stmt, err = db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -277,24 +284,22 @@ func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) {
return return
} }
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) {
var col *core.Column var col *core.Column
if col = table.GetColumnIdx(key, idx); col == nil { if col = table.GetColumnIdx(key, idx); col == nil {
//session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) return nil, ErrFieldIsNotExist{key, table.Name}
return nil
} }
fieldValue, err := col.ValueOfV(dataStruct) fieldValue, err := col.ValueOfV(dataStruct)
if err != nil { if err != nil {
session.engine.logger.Error(err) return nil, err
return nil
} }
if !fieldValue.IsValid() || !fieldValue.CanSet() { if !fieldValue.IsValid() || !fieldValue.CanSet() {
session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) return nil, ErrFieldIsNotValid{key, table.Name}
return nil
} }
return fieldValue
return fieldValue, nil
} }
// Cell cell is a result of one column field // Cell cell is a result of one column field
@ -406,405 +411,417 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
tempMap[lKey] = idx tempMap[lKey] = idx
if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil { fieldValue, err := session.getField(dataStruct, key, table, idx)
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
// if row is null then ignore session.engine.logger.Warn(err)
if rawValue.Interface() == nil {
continue
} }
continue
}
if fieldValue == nil {
continue
}
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
if fieldValue.CanAddr() { // if row is null then ignore
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { if rawValue.Interface() == nil {
if data, err := value2Bytes(&rawValue); err == nil { continue
if err := structConvert.FromDB(data); err != nil { }
return nil, err
} if fieldValue.CanAddr() {
} else { if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if err := structConvert.FromDB(data); err != nil {
return nil, err return nil, err
} }
continue
}
}
if _, ok := fieldValue.Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue.Interface().(core.Conversion).FromDB(data)
} else { } else {
return nil, err return nil, err
} }
continue continue
} }
}
rawValueType := reflect.TypeOf(rawValue.Interface()) if _, ok := fieldValue.Interface().(core.Conversion); ok {
vv := reflect.ValueOf(rawValue.Interface()) if data, err := value2Bytes(&rawValue); err == nil {
col := table.GetColumnIdx(key, idx) if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
if col.IsPrimaryKey { fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
pk = append(pk, rawValue.Interface()) }
fieldValue.Interface().(core.Conversion).FromDB(data)
} else {
return nil, err
} }
fieldType := fieldValue.Type() continue
hasAssigned := false }
if col.SQLType.IsJson() { rawValueType := reflect.TypeOf(rawValue.Interface())
var bs []byte vv := reflect.ValueOf(rawValue.Interface())
if rawValueType.Kind() == reflect.String { col := table.GetColumnIdx(key, idx)
bs = []byte(vv.String()) if col.IsPrimaryKey {
} else if rawValueType.ConvertibleTo(core.BytesType) { pk = append(pk, rawValue.Interface())
bs = vv.Bytes() }
fieldType := fieldValue.Type()
hasAssigned := false
if col.SQLType.IsJson() {
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
bs = vv.Bytes()
} else {
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
}
hasAssigned = true
if len(bs) > 0 {
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
continue
}
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
}
} else { } else {
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) x := reflect.New(fieldType)
} err := json.Unmarshal(bs, x.Interface())
if err != nil {
hasAssigned = true return nil, err
if len(bs) > 0 {
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
} }
fieldValue.Set(x.Elem())
} }
continue
} }
switch fieldType.Kind() { continue
case reflect.Complex64, reflect.Complex128: }
// TODO: reimplement this
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
bs = vv.Bytes()
}
hasAssigned = true switch fieldType.Kind() {
if len(bs) > 0 { case reflect.Complex64, reflect.Complex128:
if fieldValue.CanAddr() { // TODO: reimplement this
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) var bs []byte
if err != nil { if rawValueType.Kind() == reflect.String {
return nil, err bs = []byte(vv.String())
} } else if rawValueType.ConvertibleTo(core.BytesType) {
} else { bs = vv.Bytes()
x := reflect.New(fieldType) }
err := json.Unmarshal(bs, x.Interface())
if err != nil { hasAssigned = true
return nil, err if len(bs) > 0 {
} if fieldValue.CanAddr() {
fieldValue.Set(x.Elem()) err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
} }
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
} }
}
case reflect.Slice, reflect.Array:
switch rawValueType.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch rawValueType.Kind() { switch rawValueType.Elem().Kind() {
case reflect.Slice, reflect.Array: case reflect.Uint8:
switch rawValueType.Elem().Kind() { if fieldType.Elem().Kind() == reflect.Uint8 {
case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 {
hasAssigned = true
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
} else {
if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
}
}
}
}
}
}
case reflect.String:
if rawValueType.Kind() == reflect.String {
hasAssigned = true
fieldValue.SetString(vv.String())
}
case reflect.Bool:
if rawValueType.Kind() == reflect.Bool {
hasAssigned = true
fieldValue.SetBool(vv.Bool())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch rawValueType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetInt(vv.Int())
}
case reflect.Float32, reflect.Float64:
switch rawValueType.Kind() {
case reflect.Float32, reflect.Float64:
hasAssigned = true
fieldValue.SetFloat(vv.Float())
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
hasAssigned = true
fieldValue.SetUint(vv.Uint())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetUint(uint64(vv.Int()))
}
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
}
if rawValueType == core.TimeType {
hasAssigned = true hasAssigned = true
if col.SQLType.IsText() {
t := vv.Convert(core.TimeType).Interface().(time.Time) x := reflect.New(fieldType)
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type {
hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else {
if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else if d, ok := vv.Interface().(string); ok {
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else {
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if col.SQLType.IsJson() {
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 {
err := json.Unmarshal(vv.Bytes(), x.Interface()) err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
}
}
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil {
return nil, err
}
hasAssigned = true
if len(table.PrimaryKeys) != 1 {
return nil, errors.New("unsupported non or composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, rawValueType)
if err != nil {
return nil, err
}
if !isPKZero(pk) {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
if has {
fieldValue.Set(structInter.Elem())
} else { } else {
return nil, errors.New("cascade obj is not exist") if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
}
}
} }
} }
} }
case reflect.Ptr: }
// !nashtsai! TODO merge duplicated codes above case reflect.String:
switch fieldType { if rawValueType.Kind() == reflect.String {
// following types case matching ptr's native type, therefore assign ptr directly hasAssigned = true
case core.PtrStringType: fieldValue.SetString(vv.String())
if rawValueType.Kind() == reflect.String { }
x := vv.String() case reflect.Bool:
hasAssigned = true if rawValueType.Kind() == reflect.Bool {
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
} fieldValue.SetBool(vv.Bool())
case core.PtrBoolType: }
if rawValueType.Kind() == reflect.Bool { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x := vv.Bool() switch rawValueType.Kind() {
hasAssigned = true case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
} fieldValue.SetInt(vv.Int())
case core.PtrTimeType: }
if rawValueType == core.PtrTimeType { case reflect.Float32, reflect.Float64:
hasAssigned = true switch rawValueType.Kind() {
var x = rawValue.Interface().(time.Time) case reflect.Float32, reflect.Float64:
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
} fieldValue.SetFloat(vv.Float())
case core.PtrFloat64Type: }
if rawValueType.Kind() == reflect.Float64 { case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
x := vv.Float() switch rawValueType.Kind() {
hasAssigned = true case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
} fieldValue.SetUint(vv.Uint())
case core.PtrUint64Type: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if rawValueType.Kind() == reflect.Int64 { hasAssigned = true
var x = uint64(vv.Int()) fieldValue.SetUint(uint64(vv.Int()))
hasAssigned = true }
fieldValue.Set(reflect.ValueOf(&x)) case reflect.Struct:
} if fieldType.ConvertibleTo(core.TimeType) {
case core.PtrInt64Type: dbTZ := session.engine.DatabaseTZ
if rawValueType.Kind() == reflect.Int64 { if col.TimeZone != nil {
x := vv.Int() dbTZ = col.TimeZone
hasAssigned = true }
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Complex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
case core.Complex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
} // switch fieldType
} // switch fieldType.Kind()
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value if rawValueType == core.TimeType {
if !hasAssigned { hasAssigned = true
data, err := value2Bytes(&rawValue)
t := vv.Convert(core.TimeType).Interface().(time.Time)
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type {
hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else {
if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else if d, ok := vv.Interface().(string); ok {
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else {
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if col.SQLType.IsJson() {
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 {
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
}
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = session.bytes2Value(col, fieldValue, data); err != nil { hasAssigned = true
if len(table.PrimaryKeys) != 1 {
return nil, errors.New("unsupported non or composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, rawValueType)
if err != nil {
return nil, err return nil, err
} }
if !isPKZero(pk) {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
if has {
fieldValue.Set(structInter.Elem())
} else {
return nil, errors.New("cascade obj is not exist")
}
}
}
case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType:
if rawValueType.Kind() == reflect.String {
x := vv.String()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrBoolType:
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrTimeType:
if rawValueType == core.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Complex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
case core.Complex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
} // switch fieldType
} // switch fieldType.Kind()
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value
if !hasAssigned {
data, err := value2Bytes(&rawValue)
if err != nil {
return nil, err
}
if err = session.bytes2Value(col, fieldValue, data); err != nil {
return nil, err
} }
} }
} }
@ -823,15 +840,6 @@ func (session *Session) LastSQL() (string, []interface{}) {
return session.lastSQL, session.lastSQLArgs return session.lastSQL, session.lastSQLArgs
} }
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
// Unscoped always disable struct tag "deleted" // Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session { func (session *Session) Unscoped() *Session {
session.statement.Unscoped() session.statement.Unscoped()

View File

@ -4,6 +4,121 @@
package xorm package xorm
import (
"reflect"
"strings"
"time"
"github.com/go-xorm/core"
)
type incrParam struct {
colName string
arg interface{}
}
type decrParam struct {
colName string
arg interface{}
}
type exprParam struct {
colName string
expr string
}
type columnMap []string
func (m columnMap) contain(colName string) bool {
if len(m) == 0 {
return false
}
n := len(colName)
for _, mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, colName) {
return true
}
}
return false
}
func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
return false
}
*m = append(*m, colName)
return true
}
func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t)
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t))
}
}
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Incr provides a query string like "count = count + 1" // Incr provides a query string like "count = count + 1"
func (session *Session) Incr(column string, arg ...interface{}) *Session { func (session *Session) Incr(column string, arg ...interface{}) *Session {
session.statement.Incr(column, arg...) session.statement.Incr(column, arg...)

View File

@ -34,27 +34,27 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
sd, err := strconv.ParseInt(sdata, 10, 64) sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil { if err == nil {
x = time.Unix(sd, 0) x = time.Unix(sd, 0)
session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else { } else {
session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
} else if len(sdata) > 19 && strings.Contains(sdata, "-") { } else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
if err != nil { if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
if err != nil { if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
} else if len(sdata) == 19 && strings.Contains(sdata, "-") { } else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == core.Time { } else if col.SQLType.Name == core.Time {
if strings.Contains(sdata, " ") { if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ") ssd := strings.Split(sdata, " ")
@ -68,7 +68,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
st := fmt.Sprintf("2006-01-02 %v", sdata) st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else { } else {
outErr = fmt.Errorf("unsupported time format %v", sdata) outErr = fmt.Errorf("unsupported time format %v", sdata)
return return

View File

@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed return ErrCacheFailed
} }
cacher := session.engine.getCacher2(table) cacher := session.engine.getCacher(tableName)
pkColumns := table.PKColumns() pkColumns := table.PKColumns()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.Close() defer session.Close()
} }
if err := session.statement.setRefValue(rValue(bean)); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return 0, err return 0, err
} }
@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}) })
} }
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
} }

View File

@ -10,6 +10,7 @@ import (
"reflect" "reflect"
"github.com/go-xorm/builder" "github.com/go-xorm/builder"
"github.com/go-xorm/core"
) )
// Exist returns true if the record exist otherwise return false // Exist returns true if the record exist otherwise return false
@ -35,10 +36,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, err return false, err
} }
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s WHERE %s", tableName, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
}
args = condArgs args = condArgs
} else { } else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s", tableName)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
}
args = []interface{}{} args = []interface{}{}
} }
} else { } else {
@ -48,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
} }
if beanValue.Elem().Kind() == reflect.Struct { if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil { if err := session.statement.setRefBean(bean[0]); err != nil {
return false, err return false, err
} }
} }

View File

@ -29,6 +29,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return session.find(rowsSlicePtr, condiBean...) return session.find(rowsSlicePtr, condiBean...)
} }
// FindAndCount find the results and also return the counts
func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
session.autoResetStatement = false
err := session.find(rowsSlicePtr, condiBean...)
if err != nil {
return 0, err
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return 0, errors.New("needs a pointer to a slice or a map")
}
sliceElementType := sliceValue.Type().Elem()
if sliceElementType.Kind() == reflect.Ptr {
sliceElementType = sliceElementType.Elem()
}
session.autoResetStatement = true
if session.statement.selectStr != "" {
session.statement.selectStr = ""
}
if session.statement.OrderStr != "" {
session.statement.OrderStr = ""
}
return session.Count(reflect.New(sliceElementType).Interface())
}
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
@ -42,7 +75,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
if err := session.statement.setRefValue(pv.Elem()); err != nil { if err := session.statement.setRefValue(pv); err != nil {
return err return err
} }
} else { } else {
@ -50,7 +83,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} else if sliceElementType.Kind() == reflect.Struct { } else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType) pv := reflect.New(sliceElementType)
if err := session.statement.setRefValue(pv.Elem()); err != nil { if err := session.statement.setRefValue(pv); err != nil {
return err return err
} }
} else { } else {
@ -102,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if session.statement.JoinStr == "" { if session.statement.JoinStr == "" {
if columnStr == "" { if columnStr == "" {
if session.statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else { } else {
columnStr = session.statement.genColumnStr() columnStr = session.statement.genColumnStr()
} }
@ -110,7 +143,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} else { } else {
if columnStr == "" { if columnStr == "" {
if session.statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else { } else {
columnStr = "*" columnStr = "*"
} }
@ -128,7 +161,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
args = append(session.statement.joinArgs, condArgs...) args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil { if err != nil {
return err return err
} }
@ -143,7 +176,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
if session.canCache() { if session.canCache() {
if cacher := session.engine.getCacher2(table); cacher != nil && if cacher := session.engine.getCacher(table.Name); cacher != nil &&
!session.statement.IsDistinct && !session.statement.IsDistinct &&
!session.statement.unscoped { !session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
@ -288,6 +321,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed return ErrCacheFailed
} }
tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
if cacher == nil {
return nil
}
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
} }
@ -297,9 +336,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed return ErrCacheFailed
} }
tableName := session.statement.TableName()
table := session.statement.RefTable table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
rows, err := session.queryRows(newsql, args...) rows, err := session.queryRows(newsql, args...)

View File

@ -5,7 +5,9 @@
package xorm package xorm
import ( import (
"database/sql"
"errors" "errors"
"fmt"
"reflect" "reflect"
"strconv" "strconv"
@ -30,7 +32,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
} }
if beanValue.Elem().Kind() == reflect.Struct { if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return false, err return false, err
} }
} }
@ -56,7 +58,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
table := session.statement.RefTable table := session.statement.RefTable
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.getCacher2(table); cacher != nil && if cacher := session.engine.getCacher(table.Name); cacher != nil &&
!session.statement.unscoped { !session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
@ -65,7 +67,28 @@ func (session *Session) get(bean interface{}) (bool, error) {
} }
} }
return session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) context := session.statement.context
if context != nil {
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
if res != nil {
structValue := reflect.Indirect(reflect.ValueOf(bean))
structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
session.lastSQL = ""
session.lastSQLArgs = nil
return true, nil
}
}
has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
if err != nil || !has {
return has, err
}
if context != nil {
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean)
}
return true, nil
} }
func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
@ -76,9 +99,19 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
defer rows.Close() defer rows.Close()
if !rows.Next() { if !rows.Next() {
if rows.Err() != nil {
return false, rows.Err()
}
return false, nil return false, nil
} }
switch bean.(type) {
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString:
return true, rows.Scan(&bean)
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString:
return true, rows.Scan(bean)
}
switch beanKind { switch beanKind {
case reflect.Struct: case reflect.Struct:
fields, err := rows.Columns() fields, err := rows.Columns()
@ -126,8 +159,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed return false, ErrCacheFailed
} }
cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.statement.TableName() tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
table := session.statement.RefTable table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)

View File

@ -66,11 +66,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice") return 0, errors.New("could not insert a empty slice")
} }
if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil {
return 0, err return 0, err
} }
if len(session.statement.TableName()) <= 0 { tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
@ -115,15 +116,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.statement.ColumnStr != "" { if session.statement.omitColumnMap.contain(col.Name) {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { continue
continue
}
} }
if session.statement.OmitStr != "" { if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { continue
continue
}
} }
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col) val, t := session.engine.nowTime(col)
@ -170,15 +167,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.statement.ColumnStr != "" { if session.statement.omitColumnMap.contain(col.Name) {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { continue
continue
}
} }
if session.statement.OmitStr != "" { if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { continue
continue
}
} }
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col) val, t := session.engine.nowTime(col)
@ -211,38 +204,33 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
} }
cleanupProcessorsClosures(&session.beforeClosures) cleanupProcessorsClosures(&session.beforeClosures)
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" var sql string
var statement string
var tableName = session.statement.TableName()
if session.engine.dialect.DBType() == core.ORACLE { if session.engine.dialect.DBType() == core.ORACLE {
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr()) session.engine.QuoteStr())
statement = fmt.Sprintf(sql, sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colMultiPlaces, temp)) strings.Join(colMultiPlaces, temp))
} else { } else {
statement = fmt.Sprintf(sql, sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(), session.engine.QuoteStr(),
strings.Join(colMultiPlaces, "),(")) strings.Join(colMultiPlaces, "),("))
} }
res, err := session.exec(statement, args...) res, err := session.exec(sql, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { session.cacheInsert(tableName)
session.cacheInsert(table, tableName)
}
lenAfterClosures := len(session.afterClosures) lenAfterClosures := len(session.afterClosures)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
@ -298,7 +286,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
} }
func (session *Session) innerInsert(bean interface{}) (int64, error) { func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.setRefValue(rValue(bean)); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return 0, err return 0, err
} }
if len(session.statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
@ -316,8 +304,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
processor.BeforeInsert() processor.BeforeInsert()
} }
// --
colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false) colNames, args, err := session.genInsertColumns(bean)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -400,11 +388,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { session.cacheInsert(tableName)
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
@ -445,11 +431,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { session.cacheInsert(tableName)
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
@ -490,9 +474,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
defer handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { session.cacheInsert(tableName)
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
@ -539,16 +521,104 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return session.innerInsert(bean) return session.innerInsert(bean)
} }
func (session *Session) cacheInsert(table *core.Table, tables ...string) error { func (session *Session) cacheInsert(table string) error {
if table == nil { if !session.statement.UseCache {
return ErrCacheFailed return nil
} }
cacher := session.engine.getCacher(table)
cacher := session.engine.getCacher2(table) if cacher == nil {
for _, t := range tables { return nil
session.engine.logger.Debug("[cache] clear sql:", t)
cacher.ClearIds(t)
} }
session.engine.logger.Debug("[cache] clear sql:", table)
cacher.ClearIds(table)
return nil return nil
} }
// genInsertColumns generates insert needed columns
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if col.MapType == core.ONLYFROMDB {
continue
}
if col.IsDeleted {
continue
}
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
continue
}
if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
colNames = append(colNames, col.Name)
}
return colNames, args, nil
}

View File

@ -8,17 +8,86 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) {
if len(sqlorArgs) > 0 {
return convertSQLOrArgs(sqlorArgs...)
}
if session.statement.RawSQL != "" {
return session.statement.RawSQL, session.statement.RawParams, nil
}
if len(session.statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = session.statement.ColumnStr
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
if err := session.statement.processIDParam(); err != nil {
return "", nil, err
}
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return "", nil, err
}
args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}
// Query runs a raw sql and return records as []map[string][]byte // Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { func (session *Session) Query(sqlorArgs ...interface{}) ([]map[string][]byte, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
return session.queryBytes(sqlStr, args...) return session.queryBytes(sqlStr, args...)
} }
@ -97,6 +166,34 @@ func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string,
return result, nil return result, nil
} }
func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) {
result := make([]string, 0, len(fields))
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for i := 0; i < len(fields); i++ {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i]))
// if row is null then as empty string
if rawValue.Interface() == nil {
result = append(result, "")
continue
}
if data, err := value2String(&rawValue); err == nil {
result = append(result, data)
} else {
return nil, err
}
}
return result, nil
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
@ -113,12 +210,33 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error)
return resultsSlice, nil return resultsSlice, nil
} }
func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
record, err := row2sliceStr(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, record)
}
return resultsSlice, nil
}
// QueryString runs a raw sql and return records as []map[string]string // QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) { func (session *Session) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -128,6 +246,26 @@ func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[s
return rows2Strings(rows) return rows2Strings(rows)
} }
// QuerySliceString runs a raw sql and return records as [][]string
func (session *Session) QuerySliceString(sqlorArgs ...interface{}) ([][]string, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2SliceString(rows)
}
func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) {
resultsMap = make(map[string]interface{}, len(fields)) resultsMap = make(map[string]interface{}, len(fields))
scanResultContainers := make([]interface{}, len(fields)) scanResultContainers := make([]interface{}, len(fields))
@ -162,11 +300,16 @@ func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, er
} }
// QueryInterface runs a raw sql and return records as []map[string]interface{} // QueryInterface runs a raw sql and return records as []map[string]interface{}
func (session *Session) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) { func (session *Session) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -9,6 +9,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
@ -47,9 +48,16 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
} }
if session.isAutoCommit { if session.isAutoCommit {
var db *core.DB
if session.engine.engineGroup != nil {
db = session.engine.engineGroup.Slave().DB()
} else {
db = session.DB()
}
if session.prepareStmt { if session.prepareStmt {
// don't clear stmt since session will cache them // don't clear stmt since session will cache them
stmt, err := session.doPrepare(sqlStr) stmt, err := session.doPrepare(db, sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,7 +69,7 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
return rows, nil return rows, nil
} }
rows, err := session.DB().Query(sqlStr, args...) rows, err := db.Query(sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -171,7 +179,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
} }
if session.prepareStmt { if session.prepareStmt {
stmt, err := session.doPrepare(sqlStr) stmt, err := session.doPrepare(session.DB(), sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -186,11 +194,34 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
return session.DB().Exec(sqlStr, args...) return session.DB().Exec(sqlStr, args...)
} }
func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {
switch sqlorArgs[0].(type) {
case string:
return sqlorArgs[0].(string), sqlorArgs[1:], nil
case *builder.Builder:
return sqlorArgs[0].(*builder.Builder).ToSQL()
case builder.Builder:
bd := sqlorArgs[0].(builder.Builder)
return bd.ToSQL()
}
return "", nil, ErrUnSupportedType
}
// Exec raw sql // Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
if len(sqlorArgs) == 0 {
return nil, ErrUnSupportedType
}
sqlStr, args, err := convertSQLOrArgs(sqlorArgs...)
if err != nil {
return nil, err
}
return session.exec(sqlStr, args...) return session.exec(sqlStr, args...)
} }

View File

@ -6,9 +6,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"reflect"
"strings" "strings"
"github.com/go-xorm/core" "github.com/go-xorm/core"
@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error {
} }
func (session *Session) createTable(bean interface{}) error { func (session *Session) createTable(bean interface{}) error {
v := rValue(bean) if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error {
} }
func (session *Session) createIndexes(bean interface{}) error { func (session *Session) createIndexes(bean interface{}) error {
v := rValue(bean) if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error {
} }
func (session *Session) createUniques(bean interface{}) error { func (session *Session) createUniques(bean interface{}) error {
v := rValue(bean) if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error {
} }
func (session *Session) dropIndexes(bean interface{}) error { func (session *Session) dropIndexes(bean interface{}) error {
v := rValue(bean) if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
} }
func (session *Session) dropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName) tableName := session.engine.TableName(beanOrTableName)
if err != nil {
return err
}
var needDrop = true var needDrop = true
if !session.engine.dialect.SupportDropIfExists() { if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
} }
if needDrop { if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(tableName) sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
_, err = session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
return nil return nil
@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close() defer session.Close()
} }
tableName, err := session.engine.tableName(beanOrTableName) tableName := session.engine.TableName(beanOrTableName)
if err != nil {
return false, err
}
return session.isTableExist(tableName) return session.isTableExist(tableName)
} }
@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) {
// IsTableEmpty if table have any records // IsTableEmpty if table have any records
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
v := rValue(bean) if session.isAutoClose {
t := v.Type() defer session.Close()
if t.Kind() == reflect.String {
if session.isAutoClose {
defer session.Close()
}
return session.isTableEmpty(bean.(string))
} else if t.Kind() == reflect.Struct {
rows, err := session.Count(bean)
return rows == 0, err
} }
return false, errors.New("bean should be a struct or struct's point") return session.isTableEmpty(session.engine.TableName(bean))
} }
func (session *Session) isTableEmpty(tableName string) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64 var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
err := session.queryRow(sqlStr).Scan(&total) err := session.queryRow(sqlStr).Scan(&total)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -255,6 +233,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err return err
} }
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
var structTables []*core.Table var structTables []*core.Table
for _, bean := range beans { for _, bean := range beans {
@ -264,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err return err
} }
structTables = append(structTables, table) structTables = append(structTables, table)
var tbName = session.tbNameNoSchema(table) tbName := engine.TableName(bean)
tbNameWithSchema := engine.TableName(tbName, true)
var oriTable *core.Table var oriTable *core.Table
for _, tb := range tables { for _, tb := range tables {
@ -309,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
if engine.dialect.DBType() == core.MYSQL || if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES { engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n", engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} else { } else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} }
} }
} else { } else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbName, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} }
} else if expectedType == core.Varchar { } else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} }
} }
} }
@ -348,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
} else { } else {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbName session.statement.tableName = tbNameWithSchema
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
} }
if err != nil { if err != nil {
@ -371,7 +356,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil { if oriIndex != nil {
if oriIndex.Type != index.Type { if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex) sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -387,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok { if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2) sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -398,12 +383,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbName session.statement.tableName = tbNameWithSchema
err = session.addUnique(tbName, name) err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbName session.statement.tableName = tbNameWithSchema
err = session.addIndex(tbName, name) err = session.addIndex(tbNameWithSchema, name)
} }
if err != nil { if err != nil {
return err return err
@ -428,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for _, colName := range table.ColumnsSeq() { for _, colName := range table.ColumnsSeq() {
if oriTable.GetColumn(colName) == nil { if oriTable.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName) engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName)
} }
} }
} }

View File

@ -24,6 +24,7 @@ func (session *Session) Rollback() error {
if !session.isAutoCommit && !session.isCommitedOrRollbacked { if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL(session.engine.dialect.RollBackStr()) session.saveLastSQL(session.engine.dialect.RollBackStr())
session.isCommitedOrRollbacked = true session.isCommitedOrRollbacked = true
session.isAutoCommit = true
return session.tx.Rollback() return session.tx.Rollback()
} }
return nil return nil
@ -34,6 +35,7 @@ func (session *Session) Commit() error {
if !session.isAutoCommit && !session.isCommitedOrRollbacked { if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL("COMMIT") session.saveLastSQL("COMMIT")
session.isCommitedOrRollbacked = true session.isCommitedOrRollbacked = true
session.isAutoCommit = true
var err error var err error
if err = session.tx.Commit(); err == nil { if err = session.tx.Commit(); err == nil {
// handle processors after tx committed // handle processors after tx committed

View File

@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
} }
} }
cacher := session.engine.getCacher2(table) cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil { if err != nil {
@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct var isStruct = t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.statement.setRefValue(v); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return 0, err return 0, err
} }
@ -176,12 +176,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
if session.statement.ColumnStr == "" { if session.statement.ColumnStr == "" {
colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false, colNames, args = session.statement.buildUpdates(bean, false, false,
false, false, session.statement.allUseBool, session.statement.useAllCols, false, false, true)
session.statement.mustColumnMap, session.statement.nullableMap,
session.statement.columnMap, true, session.statement.unscoped)
} else { } else {
colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true) colNames, args, err = session.genUpdateColumns(bean)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -202,7 +200,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table := session.statement.RefTable table := session.statement.RefTable
if session.statement.UseAutoTime && table != nil && table.Updated != "" { if session.statement.UseAutoTime && table != nil && table.Updated != "" {
if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok { if !session.statement.columnMap.contain(table.Updated) &&
!session.statement.omitColumnMap.contain(table.Updated) {
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn() col := table.UpdatedColumn()
val, t := session.engine.nowTime(col) val, t := session.engine.nowTime(col)
@ -242,10 +241,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var autoCond builder.Cond var autoCond builder.Cond
if !session.statement.noAutoCondition && len(condiBean) > 0 { if !session.statement.noAutoCondition && len(condiBean) > 0 {
var err error if c, ok := condiBean[0].(map[string]interface{}); ok {
autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) autoCond = builder.Eq(c)
if err != nil { } else {
return 0, err ct := reflect.TypeOf(condiBean[0])
k := ct.Kind()
if k == reflect.Ptr {
k = ct.Elem().Kind()
}
if k == reflect.Struct {
var err error
autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}
} else {
return 0, ErrConditionType
}
} }
} }
@ -349,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
if table != nil { if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { //session.cacheUpdate(table, tableName, sqlStr, args...)
//session.cacheUpdate(table, tableName, sqlStr, args...) session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
cacher.ClearBeans(tableName) cacher.ClearBeans(tableName)
}
} }
// handle after update processors // handle after update processors
@ -389,3 +400,92 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return res.RowsAffected() return res.RowsAffected()
} }
func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated {
continue
}
if len(session.statement.columnMap) > 0 {
if !session.statement.columnMap.contain(col.Name) {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
}
return colNames, args, nil
}

View File

@ -5,7 +5,6 @@
package xorm package xorm
import ( import (
"bytes"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
@ -18,21 +17,6 @@ import (
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
type incrParam struct {
colName string
arg interface{}
}
type decrParam struct {
colName string
arg interface{}
}
type exprParam struct {
colName string
expr string
}
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *core.Table RefTable *core.Table
@ -47,7 +31,6 @@ type Statement struct {
HavingStr string HavingStr string
ColumnStr string ColumnStr string
selectStr string selectStr string
columnMap map[string]bool
useAllCols bool useAllCols bool
OmitStr string OmitStr string
AltTableName string AltTableName string
@ -67,6 +50,8 @@ type Statement struct {
allUseBool bool allUseBool bool
checkVersion bool checkVersion bool
unscoped bool unscoped bool
columnMap columnMap
omitColumnMap columnMap
mustColumnMap map[string]bool mustColumnMap map[string]bool
nullableMap map[string]bool nullableMap map[string]bool
incrColumns map[string]incrParam incrColumns map[string]incrParam
@ -74,6 +59,7 @@ type Statement struct {
exprColumns map[string]exprParam exprColumns map[string]exprParam
cond builder.Cond cond builder.Cond
bufferSize int bufferSize int
context ContextCache
} }
// Init reset all the statement's fields // Init reset all the statement's fields
@ -89,7 +75,8 @@ func (statement *Statement) Init() {
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnStr = "" statement.ColumnStr = ""
statement.OmitStr = "" statement.OmitStr = ""
statement.columnMap = make(map[string]bool) statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{}
statement.AltTableName = "" statement.AltTableName = ""
statement.tableName = "" statement.tableName = ""
statement.idParam = nil statement.idParam = nil
@ -113,6 +100,7 @@ func (statement *Statement) Init() {
statement.exprColumns = make(map[string]exprParam) statement.exprColumns = make(map[string]exprParam)
statement.cond = builder.NewCond() statement.cond = builder.NewCond()
statement.bufferSize = 0 statement.bufferSize = 0
statement.context = nil
} }
// NoAutoCondition if you do not want convert bean's field as query condition, then use this function // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
@ -160,6 +148,9 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
case string: case string:
cond := builder.Expr(query.(string), args...) cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.And(cond) statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := builder.Eq(query.(map[string]interface{}))
statement.cond = statement.cond.And(cond)
case builder.Cond: case builder.Cond:
cond := query.(builder.Cond) cond := query.(builder.Cond)
statement.cond = statement.cond.And(cond) statement.cond = statement.cond.And(cond)
@ -181,6 +172,9 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
case string: case string:
cond := builder.Expr(query.(string), args...) cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.Or(cond) statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := builder.Eq(query.(map[string]interface{}))
statement.cond = statement.cond.Or(cond)
case builder.Cond: case builder.Cond:
cond := query.(builder.Cond) cond := query.(builder.Cond)
statement.cond = statement.cond.Or(cond) statement.cond = statement.cond.Or(cond)
@ -215,34 +209,33 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil { if err != nil {
return err return err
} }
statement.tableName = statement.Engine.tbName(v) statement.tableName = statement.Engine.TableName(v, true)
return nil return nil
} }
// Table tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) setRefBean(bean interface{}) error {
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { var err error
v := rValue(tableNameOrBean) statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
t := v.Type() if err != nil {
if t.Kind() == reflect.String { return err
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.AltTableName = statement.Engine.tbName(v)
} }
return statement statement.tableName = statement.Engine.TableName(bean, true)
return nil
} }
// Auto generating update columnes and values according a struct // Auto generating update columnes and values according a struct
func buildUpdates(engine *Engine, table *core.Table, bean interface{}, func (statement *Statement) buildUpdates(bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeVersion, includeUpdated, includeNil,
includeAutoIncr bool, allUseBool bool, useAllCols bool, includeAutoIncr, update bool) ([]string, []interface{}) {
mustColumnMap map[string]bool, nullableMap map[string]bool, engine := statement.Engine
columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) { table := statement.RefTable
allUseBool := statement.allUseBool
useAllCols := statement.useAllCols
mustColumnMap := statement.mustColumnMap
nullableMap := statement.nullableMap
columnMap := statement.columnMap
omitColumnMap := statement.omitColumnMap
unscoped := statement.unscoped
var colNames = make([]string, 0) var colNames = make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
@ -262,7 +255,14 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if col.IsDeleted && !unscoped { if col.IsDeleted && !unscoped {
continue continue
} }
if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use { if omitColumnMap.contain(col.Name) {
continue
}
if len(columnMap) > 0 && !columnMap.contain(col.Name) {
continue
}
if col.MapType == core.ONLYFROMDB {
continue continue
} }
@ -598,17 +598,10 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
} }
func (statement *Statement) colmap2NewColsWithQuote() []string { func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, 0, len(statement.columnMap)) newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
for col := range statement.columnMap { copy(newColumns, statement.columnMap)
fields := strings.Split(strings.TrimSpace(col), ".") for i := 0; i < len(statement.columnMap); i++ {
if len(fields) == 1 { newColumns[i] = statement.Engine.Quote(newColumns[i])
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
} }
return newColumns return newColumns
} }
@ -636,10 +629,11 @@ func (statement *Statement) Select(str string) *Statement {
func (statement *Statement) Cols(columns ...string) *Statement { func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...) cols := col2NewCols(columns...)
for _, nc := range cols { for _, nc := range cols {
statement.columnMap[strings.ToLower(nc)] = true statement.columnMap.add(nc)
} }
newColumns := statement.colmap2NewColsWithQuote() newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ") statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
return statement return statement
@ -674,7 +668,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement {
func (statement *Statement) Omit(columns ...string) { func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
for _, nc := range newColumns { for _, nc := range newColumns {
statement.columnMap[strings.ToLower(nc)] = false statement.omitColumnMap = append(statement.omitColumnMap, nc)
} }
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
@ -713,10 +707,9 @@ func (statement *Statement) OrderBy(order string) *Statement {
// Desc generate `ORDER BY xx DESC` // Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement { func (statement *Statement) Desc(colNames ...string) *Statement {
var buf bytes.Buffer var buf builder.StringBuilder
fmt.Fprintf(&buf, statement.OrderStr)
if len(statement.OrderStr) > 0 { if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, ", ") fmt.Fprint(&buf, statement.OrderStr, ", ")
} }
newColNames := statement.col2NewColsWithQuote(colNames...) newColNames := statement.col2NewColsWithQuote(colNames...)
fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
@ -726,10 +719,9 @@ func (statement *Statement) Desc(colNames ...string) *Statement {
// Asc provide asc order by query condition, the input parameters are columns. // Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement { func (statement *Statement) Asc(colNames ...string) *Statement {
var buf bytes.Buffer var buf builder.StringBuilder
fmt.Fprintf(&buf, statement.OrderStr)
if len(statement.OrderStr) > 0 { if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, ", ") fmt.Fprint(&buf, statement.OrderStr, ", ")
} }
newColNames := statement.col2NewColsWithQuote(colNames...) newColNames := statement.col2NewColsWithQuote(colNames...)
fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
@ -737,48 +729,35 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement return statement
} }
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
}
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
return statement
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf bytes.Buffer var buf builder.StringBuilder
if len(statement.JoinStr) > 0 { if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else { } else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP) fmt.Fprintf(&buf, "%v JOIN ", joinOP)
} }
switch tablename.(type) { tbName := statement.Engine.TableName(tablename, true)
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.String {
table = f.(string)
} else if t.Kind() == reflect.Struct {
table = statement.Engine.tbName(v)
}
}
if l > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(table))
}
default:
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
}
fmt.Fprintf(&buf, " ON %v", condition) fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
statement.JoinStr = buf.String() statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...) statement.joinArgs = append(statement.joinArgs, args...)
return statement return statement
@ -803,18 +782,20 @@ func (statement *Statement) Unscoped() *Statement {
} }
func (statement *Statement) genColumnStr() string { func (statement *Statement) genColumnStr() string {
var buf bytes.Buffer
if statement.RefTable == nil { if statement.RefTable == nil {
return "" return ""
} }
var buf builder.StringBuilder
columns := statement.RefTable.Columns() columns := statement.RefTable.Columns()
for _, col := range columns { for _, col := range columns {
if statement.OmitStr != "" { if statement.omitColumnMap.contain(col.Name) {
if _, ok := getFlagForColumn(statement.columnMap, col); ok { continue
continue }
}
if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
continue
} }
if col.MapType == core.ONLYTODB { if col.MapType == core.ONLYTODB {
@ -825,10 +806,6 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(", ") buf.WriteString(", ")
} }
if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
buf.WriteString("id() AS ")
}
if statement.JoinStr != "" { if statement.JoinStr != "" {
if statement.TableAlias != "" { if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias) buf.WriteString(statement.TableAlias)
@ -853,11 +830,13 @@ func (statement *Statement) genCreateTableSQL() string {
func (statement *Statement) genIndexSQL() []string { func (statement *Statement) genIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
quote := statement.Engine.Quote for _, index := range statement.RefTable.Indexes {
for idxName, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType { if index.Type == core.IndexType {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
quote(tbName), quote(strings.Join(index.Cols, quote(",")))) /*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
} }
@ -883,16 +862,18 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string { func (statement *Statement) genDelIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes { for idxName, index := range statement.RefTable.Indexes {
var rIdxName string var rIdxName string
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
rIdxName = uniqueName(tbName, idxName) rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
rIdxName = indexName(tbName, idxName) rIdxName = indexName(idxPrefixName, idxName)
} }
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
if statement.Engine.dialect.IndexOnTable() { if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
} }
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
@ -901,8 +882,12 @@ func (statement *Statement) genDelIndexSQL() []string {
func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
quote := statement.Engine.Quote quote := statement.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()), sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
col.String(statement.Engine.dialect)) col.String(statement.Engine.dialect))
if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ";"
return sql, []interface{}{} return sql, []interface{}{}
} }
@ -939,7 +924,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
v := rValue(bean) v := rValue(bean)
isStruct := v.Kind() == reflect.Struct isStruct := v.Kind() == reflect.Struct
if isStruct { if isStruct {
statement.setRefValue(v) statement.setRefBean(bean)
} }
var columnStr = statement.ColumnStr var columnStr = statement.ColumnStr
@ -950,7 +935,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if len(statement.JoinStr) == 0 { if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 { if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
} else { } else {
columnStr = statement.genColumnStr() columnStr = statement.genColumnStr()
} }
@ -958,7 +943,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
} else { } else {
if len(columnStr) == 0 { if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
} }
} }
} }
@ -972,13 +957,17 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if err := statement.mergeConds(bean); err != nil { if err := statement.mergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
} else {
if err := statement.processIDParam(); err != nil {
return "", nil, err
}
} }
condSQL, condArgs, err := builder.ToSQL(statement.cond) condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
sqlStr, err := statement.genSelectSQL(columnStr, condSQL) sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -991,7 +980,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
var condArgs []interface{} var condArgs []interface{}
var err error var err error
if len(beans) > 0 { if len(beans) > 0 {
statement.setRefValue(rValue(beans[0])) statement.setRefBean(beans[0])
condSQL, condArgs, err = statement.genConds(beans[0]) condSQL, condArgs, err = statement.genConds(beans[0])
} else { } else {
condSQL, condArgs, err = builder.ToSQL(statement.cond) condSQL, condArgs, err = builder.ToSQL(statement.cond)
@ -1008,7 +997,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)" selectSQL = "count(*)"
} }
} }
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -1017,7 +1006,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
} }
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
statement.setRefValue(rValue(bean)) statement.setRefBean(bean)
var sumStrs = make([]string, 0, len(columns)) var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns { for _, colName := range columns {
@ -1033,7 +1022,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err return "", nil, err
} }
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -1041,27 +1030,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
var distinct string var (
distinct string
dialect = statement.Engine.Dialect()
quote = statement.Engine.Quote
fromStr = " FROM "
top, mssqlCondi, whereStr string
)
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT " distinct = "DISTINCT "
} }
var dialect = statement.Engine.Dialect()
var quote = statement.Engine.Quote
var top string
var mssqlCondi string
if err := statement.processIDParam(); err != nil {
return "", err
}
var buf bytes.Buffer
if len(condSQL) > 0 { if len(condSQL) > 0 {
fmt.Fprintf(&buf, " WHERE %v", condSQL) whereStr = " WHERE " + condSQL
} }
var whereStr = buf.String()
var fromStr = " FROM "
if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName() fromStr += statement.TableName()
@ -1108,9 +1090,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
} }
var orderStr string var orderStr string
if len(statement.OrderStr) > 0 { if needOrderBy && len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr orderStr = " ORDER BY " + statement.OrderStr
} }
var groupStr string var groupStr string
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr groupStr = " GROUP BY " + statement.GroupByStr
@ -1120,45 +1103,50 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
} }
} }
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern var buf builder.StringBuilder
a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if len(mssqlCondi) > 0 { if len(mssqlCondi) > 0 {
if len(whereStr) > 0 { if len(whereStr) > 0 {
a += " AND " + mssqlCondi fmt.Fprint(&buf, " AND ", mssqlCondi)
} else { } else {
a += " WHERE " + mssqlCondi fmt.Fprint(&buf, " WHERE ", mssqlCondi)
} }
} }
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr) fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
} }
if statement.HavingStr != "" { if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr) fmt.Fprint(&buf, " ", statement.HavingStr)
} }
if statement.OrderStr != "" { if needOrderBy && statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
} }
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if needLimit {
if statement.Start > 0 { if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) if statement.Start > 0 {
} else if statement.LimitN > 0 { fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) } else if statement.LimitN > 0 {
} fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
} else if dialect.DBType() == core.ORACLE { }
if statement.Start != 0 || statement.LimitN != 0 { } else if dialect.DBType() == core.ORACLE {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) if statement.Start != 0 || statement.LimitN != 0 {
oldString := buf.String()
buf.Reset()
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start)
}
} }
} }
if statement.IsForUpdate { if statement.IsForUpdate {
a = dialect.ForUpdateSql(a) return dialect.ForUpdateSql(buf.String()), nil
} }
return return buf.String(), nil
} }
func (statement *Statement) processIDParam() error { func (statement *Statement) processIDParam() error {
if statement.idParam == nil { if statement.idParam == nil || statement.RefTable == nil {
return nil return nil
} }

View File

@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build go1.8
package xorm package xorm
import ( import (
@ -17,7 +19,7 @@ import (
const ( const (
// Version show the xorm's version // Version show the xorm's version
Version string = "0.6.4.0910" Version string = "0.7.0.0504"
) )
func regDrvsNDialects() bool { func regDrvsNDialects() bool {
@ -31,7 +33,7 @@ func regDrvsNDialects() bool {
"mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
"pgx": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, "pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
"goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
@ -90,6 +92,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
TagIdentifier: "xorm", TagIdentifier: "xorm",
TZLocation: time.Local, TZLocation: time.Local,
tagHandlers: defaultTagHandlers, tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher),
} }
if uri.DbType == core.SQLITE { if uri.DbType == core.SQLITE {
@ -108,6 +111,13 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return engine, nil return engine, nil
} }
// NewEngineWithParams new a db manager with params. The params will be passed to dialect.
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) {
engine, err := NewEngine(driverName, dataSourceName)
engine.dialect.SetParams(params)
return engine, err
}
// Clone clone an engine // Clone clone an engine
func (engine *Engine) Clone() (*Engine, error) { func (engine *Engine) Clone() (*Engine, error) {
return NewEngine(engine.DriverName(), engine.DataSourceName()) return NewEngine(engine.DriverName(), engine.DataSourceName())

9
vendor/modules.txt vendored
View File

@ -49,12 +49,11 @@ github.com/go-redis/redis/internal/singleflight
github.com/go-redis/redis/internal/util github.com/go-redis/redis/internal/util
# github.com/go-sql-driver/mysql v1.4.1 # github.com/go-sql-driver/mysql v1.4.1
github.com/go-sql-driver/mysql github.com/go-sql-driver/mysql
# github.com/go-xorm/builder v0.0.0-20170519032130-c8871c857d25 # github.com/go-xorm/builder v0.3.2
github.com/go-xorm/builder github.com/go-xorm/builder
# github.com/go-xorm/core v0.5.8 # github.com/go-xorm/core v0.6.0
github.com/go-xorm/core github.com/go-xorm/core
# github.com/go-xorm/xorm v0.0.0-20170930012613-29d4a0330a00 # github.com/go-xorm/xorm v0.7.1
github.com/go-xorm/xorm/migrate
github.com/go-xorm/xorm github.com/go-xorm/xorm
# github.com/go-xorm/xorm-redis-cache v0.0.0-20180727005610-859b313566b2 # github.com/go-xorm/xorm-redis-cache v0.0.0-20180727005610-859b313566b2
github.com/go-xorm/xorm-redis-cache github.com/go-xorm/xorm-redis-cache
@ -217,3 +216,5 @@ honnef.co/go/tools/ssautil
honnef.co/go/tools/staticcheck/vrp honnef.co/go/tools/staticcheck/vrp
honnef.co/go/tools/callgraph honnef.co/go/tools/callgraph
honnef.co/go/tools/callgraph/static honnef.co/go/tools/callgraph/static
# src.techknowlogick.com/xormigrate v0.0.0-20190321151057-24497c23c09c
src.techknowlogick.com/xormigrate