Merge remote-tracking branch 'upstream/master' into jay/soft-delete-alive

This commit is contained in:
Jay Taylor 2015-06-11 13:54:56 -07:00
commit d65e8b3546
9 changed files with 64 additions and 35 deletions

View File

@ -70,18 +70,24 @@ func Create(scope *Scope) {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected()
if primaryField != nil {
if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id))
}
}
}
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1
} else {
scope.Err(err)
}
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1
}
}
}

View File

@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)

View File

@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
}
}
}
func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) {
s.NewScope(source).GetModelStruct().TableName = handler
}

View File

@ -61,6 +61,19 @@ func init() {
runMigration()
}
func TestStringPrimaryKey(t *testing.T) {
type UUIDStruct struct {
ID string `gorm:"primary_key"`
Name string
}
DB.AutoMigrate(&UUIDStruct{})
data := UUIDStruct{ID: "uuid", Name: "hello"}
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
t.Errorf("string primary key should not be populated")
}
}
func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {

View File

@ -13,11 +13,19 @@ import (
var modelStructs = map[reflect.Type]*ModelStruct{}
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
type ModelStruct struct {
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
TableName func(*DB) string
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
defaultTableName string
}
func (s ModelStruct) TableName(db *DB) string {
return DefaultTableNameHandler(db, s.defaultTableName)
}
type StructField struct {
@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
// Set tablename
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
if name, ok := results[0].Interface().(string); ok {
modelStruct.TableName = func(*DB) string {
return name
}
}
}
type tabler interface {
TableName() string
}
if tabler, ok := reflect.New(scopeType).Interface().(interface {
TableName() string
}); ok {
modelStruct.defaultTableName = tabler.TableName()
} else {
name := ToDBName(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable {
@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}
modelStruct.TableName = func(*DB) string {
return name
}
modelStruct.defaultTableName = name
}
// Get all fields

View File

@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease {
return "int AUTO_INCREMENT"
}
return "int"
case reflect.Int64, reflect.Uint64:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int unsigned AUTO_INCREMENT"
}
return "int unsigned"
case reflect.Int64:
if autoIncrease {
return "bigint AUTO_INCREMENT"
}
return "bigint"
case reflect.Uint64:
if autoIncrease {
return "bigint unsigned AUTO_INCREMENT"
}
return "bigint unsigned"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:

View File

@ -216,13 +216,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope {
}
fieldStruct, _ := modelType.FieldByName(column)
var columns reflect.Value
if fieldStruct.Type.Kind() == reflect.Slice {
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
} else {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
}
for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Ptr {
column = column.Elem()
}
if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ {
columns = reflect.Append(columns, column.Index(i).Addr())

View File

@ -251,12 +251,7 @@ func (scope *Scope) TableName() string {
return tabler.TableName(scope.db)
}
if scope.GetModelStruct().TableName != nil {
return scope.GetModelStruct().TableName(scope.db)
}
scope.Err(errors.New("wrong table name"))
return ""
return scope.GetModelStruct().TableName(scope.db)
}
func (scope *Scope) QuotedTableName() (name string) {
@ -278,7 +273,7 @@ func (scope *Scope) CombinedConditionSql() string {
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
for _, field := range scope.Fields() {
if field.Name == name {
if field.Name == name || field.DBName == name {
return field, true
}
}

View File

@ -480,7 +480,7 @@ func (scope *Scope) createTable() *Scope {
}
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, field.DBName)
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
scope.createJoinTable(field)
}