Merge remote-tracking branch 'upstream/master' into jay/soft-delete-alive

This commit is contained in:
Jay Taylor 2015-06-11 13:54:56 -07:00
commit d65e8b3546
9 changed files with 64 additions and 35 deletions

View File

@ -70,18 +70,24 @@ func Create(scope *Scope) {
id, err := result.LastInsertId() id, err := result.LastInsertId()
if scope.Err(err) == nil { if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
if primaryField != nil { if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id)) scope.Err(scope.SetColumn(primaryField, id))
} }
} }
} }
} else { } else {
if primaryField == nil { if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil { if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected() scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1
} else {
scope.Err(err)
} }
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1
} }
} }
} }

View File

@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go) [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go) [Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go) [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)

View File

@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
} }
} }
} }
func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) {
s.NewScope(source).GetModelStruct().TableName = handler
}

View File

@ -61,6 +61,19 @@ func init() {
runMigration() runMigration()
} }
func TestStringPrimaryKey(t *testing.T) {
type UUIDStruct struct {
ID string `gorm:"primary_key"`
Name string
}
DB.AutoMigrate(&UUIDStruct{})
data := UUIDStruct{ID: "uuid", Name: "hello"}
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
t.Errorf("string primary key should not be populated")
}
}
func TestExceptionsWithInvalidSql(t *testing.T) { func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {

View File

@ -13,11 +13,19 @@ import (
var modelStructs = map[reflect.Type]*ModelStruct{} var modelStructs = map[reflect.Type]*ModelStruct{}
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
type ModelStruct struct { type ModelStruct struct {
PrimaryFields []*StructField PrimaryFields []*StructField
StructFields []*StructField StructFields []*StructField
ModelType reflect.Type ModelType reflect.Type
TableName func(*DB) string defaultTableName string
}
func (s ModelStruct) TableName(db *DB) string {
return DefaultTableNameHandler(db, s.defaultTableName)
} }
type StructField struct { type StructField struct {
@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
// Set tablename // Set tablename
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { type tabler interface {
if results := fm.Call([]reflect.Value{}); len(results) > 0 { TableName() string
if name, ok := results[0].Interface().(string); ok { }
modelStruct.TableName = func(*DB) string {
return name if tabler, ok := reflect.New(scopeType).Interface().(interface {
} TableName() string
} }); ok {
} modelStruct.defaultTableName = tabler.TableName()
} else { } else {
name := ToDBName(scopeType.Name()) name := ToDBName(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable { if scope.db == nil || !scope.db.parent.singularTable {
@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
modelStruct.TableName = func(*DB) string { modelStruct.defaultTableName = name
return name
}
} }
// Get all fields // Get all fields

View File

@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease { if autoIncrease {
return "int AUTO_INCREMENT" return "int AUTO_INCREMENT"
} }
return "int" return "int"
case reflect.Int64, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int unsigned AUTO_INCREMENT"
}
return "int unsigned"
case reflect.Int64:
if autoIncrease { if autoIncrease {
return "bigint AUTO_INCREMENT" return "bigint AUTO_INCREMENT"
} }
return "bigint" return "bigint"
case reflect.Uint64:
if autoIncrease {
return "bigint unsigned AUTO_INCREMENT"
}
return "bigint unsigned"
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return "double" return "double"
case reflect.String: case reflect.String:

View File

@ -216,13 +216,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope {
} }
fieldStruct, _ := modelType.FieldByName(column) fieldStruct, _ := modelType.FieldByName(column)
var columns reflect.Value var columns reflect.Value
if fieldStruct.Type.Kind() == reflect.Slice { if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
} else { } else {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
} }
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column) column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Ptr {
column = column.Elem()
}
if column.Kind() == reflect.Slice { if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ { for i := 0; i < column.Len(); i++ {
columns = reflect.Append(columns, column.Index(i).Addr()) columns = reflect.Append(columns, column.Index(i).Addr())

View File

@ -251,12 +251,7 @@ func (scope *Scope) TableName() string {
return tabler.TableName(scope.db) return tabler.TableName(scope.db)
} }
if scope.GetModelStruct().TableName != nil { return scope.GetModelStruct().TableName(scope.db)
return scope.GetModelStruct().TableName(scope.db)
}
scope.Err(errors.New("wrong table name"))
return ""
} }
func (scope *Scope) QuotedTableName() (name string) { func (scope *Scope) QuotedTableName() (name string) {
@ -278,7 +273,7 @@ func (scope *Scope) CombinedConditionSql() string {
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if field.Name == name { if field.Name == name || field.DBName == name {
return field, true return field, true
} }
} }

View File

@ -480,7 +480,7 @@ func (scope *Scope) createTable() *Scope {
} }
if field.IsPrimaryKey { if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, field.DBName) primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
} }
scope.createJoinTable(field) scope.createJoinTable(field)
} }