Merge remote-tracking branch 'upstream/master' into jay/soft-delete-alive
This commit is contained in:
commit
d65e8b3546
@ -70,18 +70,24 @@ func Create(scope *Scope) {
|
||||
id, err := result.LastInsertId()
|
||||
if scope.Err(err) == nil {
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
if primaryField != nil {
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
scope.Err(scope.SetColumn(primaryField, id))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if primaryField == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
|
||||
} else {
|
||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
||||
scope.db.RowsAffected = 1
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri
|
||||
|
||||
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
|
||||
|
||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
|
||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
|
||||
|
||||
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
|
||||
|
||||
|
4
main.go
4
main.go
@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) {
|
||||
s.NewScope(source).GetModelStruct().TableName = handler
|
||||
}
|
||||
|
13
main_test.go
13
main_test.go
@ -61,6 +61,19 @@ func init() {
|
||||
runMigration()
|
||||
}
|
||||
|
||||
func TestStringPrimaryKey(t *testing.T) {
|
||||
type UUIDStruct struct {
|
||||
ID string `gorm:"primary_key"`
|
||||
Name string
|
||||
}
|
||||
DB.AutoMigrate(&UUIDStruct{})
|
||||
|
||||
data := UUIDStruct{ID: "uuid", Name: "hello"}
|
||||
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
|
||||
t.Errorf("string primary key should not be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||
var columns []string
|
||||
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||
|
@ -13,11 +13,19 @@ import (
|
||||
|
||||
var modelStructs = map[reflect.Type]*ModelStruct{}
|
||||
|
||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||
return defaultTableName
|
||||
}
|
||||
|
||||
type ModelStruct struct {
|
||||
PrimaryFields []*StructField
|
||||
StructFields []*StructField
|
||||
ModelType reflect.Type
|
||||
TableName func(*DB) string
|
||||
defaultTableName string
|
||||
}
|
||||
|
||||
func (s ModelStruct) TableName(db *DB) string {
|
||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||
}
|
||||
|
||||
type StructField struct {
|
||||
@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
|
||||
// Set tablename
|
||||
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
|
||||
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
|
||||
if name, ok := results[0].Interface().(string); ok {
|
||||
modelStruct.TableName = func(*DB) string {
|
||||
return name
|
||||
}
|
||||
}
|
||||
type tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
if tabler, ok := reflect.New(scopeType).Interface().(interface {
|
||||
TableName() string
|
||||
}); ok {
|
||||
modelStruct.defaultTableName = tabler.TableName()
|
||||
} else {
|
||||
name := ToDBName(scopeType.Name())
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
}
|
||||
|
||||
modelStruct.TableName = func(*DB) string {
|
||||
return name
|
||||
}
|
||||
modelStruct.defaultTableName = name
|
||||
}
|
||||
|
||||
// Get all fields
|
||||
|
14
mysql.go
14
mysql.go
@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||
if autoIncrease {
|
||||
return "int AUTO_INCREMENT"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "int unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "int unsigned"
|
||||
case reflect.Int64:
|
||||
if autoIncrease {
|
||||
return "bigint AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigint unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint unsigned"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "double"
|
||||
case reflect.String:
|
||||
|
@ -216,13 +216,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope {
|
||||
}
|
||||
fieldStruct, _ := modelType.FieldByName(column)
|
||||
var columns reflect.Value
|
||||
if fieldStruct.Type.Kind() == reflect.Slice {
|
||||
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
|
||||
} else {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
|
||||
}
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
||||
if column.Kind() == reflect.Ptr {
|
||||
column = column.Elem()
|
||||
}
|
||||
if column.Kind() == reflect.Slice {
|
||||
for i := 0; i < column.Len(); i++ {
|
||||
columns = reflect.Append(columns, column.Index(i).Addr())
|
||||
|
7
scope.go
7
scope.go
@ -251,12 +251,7 @@ func (scope *Scope) TableName() string {
|
||||
return tabler.TableName(scope.db)
|
||||
}
|
||||
|
||||
if scope.GetModelStruct().TableName != nil {
|
||||
return scope.GetModelStruct().TableName(scope.db)
|
||||
}
|
||||
|
||||
scope.Err(errors.New("wrong table name"))
|
||||
return ""
|
||||
}
|
||||
|
||||
func (scope *Scope) QuotedTableName() (name string) {
|
||||
@ -278,7 +273,7 @@ func (scope *Scope) CombinedConditionSql() string {
|
||||
|
||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == name {
|
||||
if field.Name == name || field.DBName == name {
|
||||
return field, true
|
||||
}
|
||||
}
|
||||
|
@ -480,7 +480,7 @@ func (scope *Scope) createTable() *Scope {
|
||||
}
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
primaryKeys = append(primaryKeys, field.DBName)
|
||||
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user