Refactor sqlite

This commit is contained in:
Jinzhu 2018-03-18 20:54:39 +08:00
parent ee50dfa818
commit 7065ea8c97
7 changed files with 52 additions and 29 deletions

8
api.go
View File

@ -148,8 +148,9 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
// has inline condition
if len(where) > 0 {
clone := tx.clone()
stmt = s.Statement.Clone()
stmt = tx.Statement.Clone()
stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...))
clone.Statement = stmt
tx.AddError(clone.Dialect().Query(clone))
tx.AddError(clone.Error)
} else {
@ -214,8 +215,9 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
// has inline condition
if len(where) > 0 {
clone := tx.clone()
stmt = s.Statement.Clone()
stmt = tx.Statement.Clone()
stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...))
clone.Statement = stmt
tx.AddError(clone.Dialect().Update(clone))
tx.AddError(clone.Error)
} else {
@ -247,7 +249,7 @@ func (s *DB) Table(name string) *DB {
func (s *DB) AddError(err error) {
if err != nil {
if err != ErrRecordNotFound {
s.Config.Logger.Error(err)
s.Config.Logger.Error(err.Error())
}
if errs := s.GetErrors(); len(errs) == 0 {

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/logger"
// import sqlite3 driver
_ "github.com/mattn/go-sqlite3"
)
@ -16,6 +17,7 @@ func Open(dsn string, config Config) (*gorm.DB, error) {
dialect, err := New(dsn)
config.Dialect = dialect
gormConfig := gorm.Config(config)
gormConfig.Logger = logger.DefaultLogger
return &gorm.DB{Config: &gormConfig}, err
}

View File

@ -4,9 +4,9 @@ import (
"fmt"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/jinzhu/gorm"
)
@ -31,11 +31,18 @@ func TestBatchInsert(t *testing.T) {
DB.Create(users)
spew.Dump(users)
for _, user := range users {
if user.ID == 0 {
t.Errorf("User should have primary key")
}
var newUser User
if err := DB.Find(&newUser, "id = ?", user.ID).Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(&newUser, user) {
t.Errorf("User should be equal, but got %#v, should be %#v", newUser, user)
}
}
}

View File

@ -68,7 +68,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
}
valueBuffer.WriteString("?")
if field.IsBlank {
if (field.Field.IsPrimaryKey || field.HasDefaultValue) && field.IsBlank {
args = append(args, nil)
} else {
args = append(args, field.Value.Interface())
@ -136,7 +136,7 @@ func (dialect *Dialect) Query(tx *gorm.DB) (err error) {
// Join SQL
if builder := <-joinChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
if len(tx.Statement.Conditions) > 0 {
@ -147,20 +147,20 @@ func (dialect *Dialect) Query(tx *gorm.DB) (err error) {
if builder := <-groupChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
if builder := <-orderChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
if builder := <-limitChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
rows, err := dialect.DB.Query(s.String(), args)
rows, err := dialect.DB.Query(s.String(), args...)
if err == nil {
err = scanRows(rows, tx.Statement.Dest)
@ -178,10 +178,7 @@ func scanRows(rows *sql.Rows, values interface{}) (err error) {
if kind := results.Kind(); kind == reflect.Slice {
isSlice = true
resultType := results.Type().Elem()
results.Set(reflect.MakeSlice(resultType, 0, 0))
} else if kind != reflect.Struct || kind != reflect.Map {
return errors.New("unsupported destination, should be slice or map or struct")
results.Set(reflect.MakeSlice(results.Type().Elem(), 0, 0))
}
for rows.Next() {
@ -209,6 +206,7 @@ func scanRows(rows *sql.Rows, values interface{}) (err error) {
}
func toScanMap(columns []string, elem reflect.Value) (results []interface{}, err error) {
var ignored interface{}
results = make([]interface{}, len(columns))
switch elem.Kind() {
@ -219,10 +217,12 @@ func toScanMap(columns []string, elem reflect.Value) (results []interface{}, err
results[idx] = &value
}
case reflect.Struct:
fieldsMap := model.Parse(elem.Interface()).FieldsMap()
fieldsMap := model.Parse(elem.Addr().Interface()).FieldsMap()
for idx, column := range columns {
if f, ok := fieldsMap[column]; ok {
results[idx] = f.Value.Addr().Interface()
results[idx] = f.Value.Interface()
} else {
results[idx] = &ignored
}
}
case reflect.Ptr:
@ -276,13 +276,15 @@ func (dialect *Dialect) Update(tx *gorm.DB) (err error) {
if builder := <-orderChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
if builder := <-limitChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
_, err = dialect.DB.Exec(s.String(), args...)
return err
}
@ -306,13 +308,14 @@ func (dialect *Dialect) Delete(tx *gorm.DB) (err error) {
if builder := <-orderChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
if builder := <-limitChan; builder != nil {
_, err = builder.SQL.WriteTo(s)
args = append(args, builder.Args)
args = append(args, builder.Args...)
}
_, err = dialect.DB.Exec(s.String(), args...)
return
}

View File

@ -9,9 +9,9 @@ import (
// Interface logger interface
type Interface interface {
SQL(data ...interface{})
Info(data ...interface{})
Warn(data ...interface{})
Error(data ...interface{})
Info(msg string, data ...interface{})
Warn(msg string, data ...interface{})
Error(msg string, data ...interface{})
}
// LogLevel log level

View File

@ -19,7 +19,7 @@ var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string
// Parse parse model
func Parse(value interface{}) *Model {
return &Model{
ReflectValue: reflect.ValueOf(value),
ReflectValue: reflect.Indirect(reflect.ValueOf(value)),
Schema: schema.Parse(value),
}
}
@ -39,9 +39,16 @@ func (model *Model) FieldsMap() map[string]*Field {
for _, bn := range sf.BindNames {
obj = obj.FieldByName(bn)
}
field := &Field{Field: sf, Value: obj}
fieldsMap[sf.DBName] = field
if obj.Kind() == reflect.Ptr {
if obj.IsNil() {
obj.Set(reflect.New(obj.Type().Elem()))
}
fieldsMap[sf.DBName] = &Field{Field: sf, Value: obj.Addr()}
} else {
fieldsMap[sf.DBName] = &Field{Field: sf, Value: obj.Addr()}
}
}
return fieldsMap

View File

@ -1,6 +1,8 @@
package gorm
import "sync"
import (
"sync"
)
// Column column type
type Column = string