diff --git a/api.go b/api.go index 6be438bf..faab1892 100644 --- a/api.go +++ b/api.go @@ -148,8 +148,9 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { // has inline condition if len(where) > 0 { clone := tx.clone() - stmt = s.Statement.Clone() + stmt = tx.Statement.Clone() stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...)) + clone.Statement = stmt tx.AddError(clone.Dialect().Query(clone)) tx.AddError(clone.Error) } else { @@ -214,8 +215,9 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB { // has inline condition if len(where) > 0 { clone := tx.clone() - stmt = s.Statement.Clone() + stmt = tx.Statement.Clone() stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...)) + clone.Statement = stmt tx.AddError(clone.Dialect().Update(clone)) tx.AddError(clone.Error) } else { @@ -247,7 +249,7 @@ func (s *DB) Table(name string) *DB { func (s *DB) AddError(err error) { if err != nil { if err != ErrRecordNotFound { - s.Config.Logger.Error(err) + s.Config.Logger.Error(err.Error()) } if errs := s.GetErrors(); len(errs) == 0 { diff --git a/dialects/sqlite/main.go b/dialects/sqlite/main.go index 49829f97..e5179b0e 100644 --- a/dialects/sqlite/main.go +++ b/dialects/sqlite/main.go @@ -4,6 +4,7 @@ import ( "database/sql" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/logger" // import sqlite3 driver _ "github.com/mattn/go-sqlite3" ) @@ -16,6 +17,7 @@ func Open(dsn string, config Config) (*gorm.DB, error) { dialect, err := New(dsn) config.Dialect = dialect gormConfig := gorm.Config(config) + gormConfig.Logger = logger.DefaultLogger return &gorm.DB{Config: &gormConfig}, err } diff --git a/dialects/sqlite/main_test.go b/dialects/sqlite/main_test.go index 9039f5fc..7a87099b 100644 --- a/dialects/sqlite/main_test.go +++ b/dialects/sqlite/main_test.go @@ -4,9 +4,9 @@ import ( "fmt" "os" "path/filepath" + "reflect" "testing" - "github.com/davecgh/go-spew/spew" "github.com/jinzhu/gorm" ) @@ -31,11 +31,18 @@ func TestBatchInsert(t *testing.T) { DB.Create(users) - spew.Dump(users) - for _, user := range users { if user.ID == 0 { t.Errorf("User should have primary key") } + + var newUser User + if err := DB.Find(&newUser, "id = ?", user.ID).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(&newUser, user) { + t.Errorf("User should be equal, but got %#v, should be %#v", newUser, user) + } } } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index abde9e64..6d41db8c 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -68,7 +68,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { } valueBuffer.WriteString("?") - if field.IsBlank { + if (field.Field.IsPrimaryKey || field.HasDefaultValue) && field.IsBlank { args = append(args, nil) } else { args = append(args, field.Value.Interface()) @@ -136,7 +136,7 @@ func (dialect *Dialect) Query(tx *gorm.DB) (err error) { // Join SQL if builder := <-joinChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } if len(tx.Statement.Conditions) > 0 { @@ -147,20 +147,20 @@ func (dialect *Dialect) Query(tx *gorm.DB) (err error) { if builder := <-groupChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } if builder := <-orderChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } if builder := <-limitChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } - rows, err := dialect.DB.Query(s.String(), args) + rows, err := dialect.DB.Query(s.String(), args...) if err == nil { err = scanRows(rows, tx.Statement.Dest) @@ -178,10 +178,7 @@ func scanRows(rows *sql.Rows, values interface{}) (err error) { if kind := results.Kind(); kind == reflect.Slice { isSlice = true - resultType := results.Type().Elem() - results.Set(reflect.MakeSlice(resultType, 0, 0)) - } else if kind != reflect.Struct || kind != reflect.Map { - return errors.New("unsupported destination, should be slice or map or struct") + results.Set(reflect.MakeSlice(results.Type().Elem(), 0, 0)) } for rows.Next() { @@ -209,6 +206,7 @@ func scanRows(rows *sql.Rows, values interface{}) (err error) { } func toScanMap(columns []string, elem reflect.Value) (results []interface{}, err error) { + var ignored interface{} results = make([]interface{}, len(columns)) switch elem.Kind() { @@ -219,10 +217,12 @@ func toScanMap(columns []string, elem reflect.Value) (results []interface{}, err results[idx] = &value } case reflect.Struct: - fieldsMap := model.Parse(elem.Interface()).FieldsMap() + fieldsMap := model.Parse(elem.Addr().Interface()).FieldsMap() for idx, column := range columns { if f, ok := fieldsMap[column]; ok { - results[idx] = f.Value.Addr().Interface() + results[idx] = f.Value.Interface() + } else { + results[idx] = &ignored } } case reflect.Ptr: @@ -276,13 +276,15 @@ func (dialect *Dialect) Update(tx *gorm.DB) (err error) { if builder := <-orderChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } if builder := <-limitChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } + + _, err = dialect.DB.Exec(s.String(), args...) return err } @@ -306,13 +308,14 @@ func (dialect *Dialect) Delete(tx *gorm.DB) (err error) { if builder := <-orderChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } if builder := <-limitChan; builder != nil { _, err = builder.SQL.WriteTo(s) - args = append(args, builder.Args) + args = append(args, builder.Args...) } + _, err = dialect.DB.Exec(s.String(), args...) return } diff --git a/logger/logger.go b/logger/logger.go index 113c773b..c7098103 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -9,9 +9,9 @@ import ( // Interface logger interface type Interface interface { SQL(data ...interface{}) - Info(data ...interface{}) - Warn(data ...interface{}) - Error(data ...interface{}) + Info(msg string, data ...interface{}) + Warn(msg string, data ...interface{}) + Error(msg string, data ...interface{}) } // LogLevel log level diff --git a/model/model.go b/model/model.go index 36808b84..33bf2163 100644 --- a/model/model.go +++ b/model/model.go @@ -19,7 +19,7 @@ var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string // Parse parse model func Parse(value interface{}) *Model { return &Model{ - ReflectValue: reflect.ValueOf(value), + ReflectValue: reflect.Indirect(reflect.ValueOf(value)), Schema: schema.Parse(value), } } @@ -39,9 +39,16 @@ func (model *Model) FieldsMap() map[string]*Field { for _, bn := range sf.BindNames { obj = obj.FieldByName(bn) } - field := &Field{Field: sf, Value: obj} - fieldsMap[sf.DBName] = field + if obj.Kind() == reflect.Ptr { + if obj.IsNil() { + obj.Set(reflect.New(obj.Type().Elem())) + } + fieldsMap[sf.DBName] = &Field{Field: sf, Value: obj.Addr()} + } else { + fieldsMap[sf.DBName] = &Field{Field: sf, Value: obj.Addr()} + } + } return fieldsMap diff --git a/statement.go b/statement.go index e09a28e1..4c227304 100644 --- a/statement.go +++ b/statement.go @@ -1,6 +1,8 @@ package gorm -import "sync" +import ( + "sync" +) // Column column type type Column = string