Merge remote-tracking branch 'upstream/master' into jay/soft-delete-alive
This commit is contained in:
commit
d65e8b3546
@ -70,18 +70,24 @@ func Create(scope *Scope) {
|
|||||||
id, err := result.LastInsertId()
|
id, err := result.LastInsertId()
|
||||||
if scope.Err(err) == nil {
|
if scope.Err(err) == nil {
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
if primaryField != nil {
|
if primaryField != nil && primaryField.IsBlank {
|
||||||
scope.Err(scope.SetColumn(primaryField, id))
|
scope.Err(scope.SetColumn(primaryField, id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if primaryField == nil {
|
if primaryField == nil {
|
||||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
|
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
||||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||||
|
} else {
|
||||||
|
scope.Err(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
||||||
|
scope.db.RowsAffected = 1
|
||||||
|
} else {
|
||||||
|
scope.Err(err)
|
||||||
}
|
}
|
||||||
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
|
|
||||||
scope.db.RowsAffected = 1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri
|
|||||||
|
|
||||||
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
|
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
|
||||||
|
|
||||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
|
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
|
||||||
|
|
||||||
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
|
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
|
||||||
|
|
||||||
|
4
main.go
4
main.go
@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) {
|
|
||||||
s.NewScope(source).GetModelStruct().TableName = handler
|
|
||||||
}
|
|
||||||
|
13
main_test.go
13
main_test.go
@ -61,6 +61,19 @@ func init() {
|
|||||||
runMigration()
|
runMigration()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStringPrimaryKey(t *testing.T) {
|
||||||
|
type UUIDStruct struct {
|
||||||
|
ID string `gorm:"primary_key"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
DB.AutoMigrate(&UUIDStruct{})
|
||||||
|
|
||||||
|
data := UUIDStruct{ID: "uuid", Name: "hello"}
|
||||||
|
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
|
||||||
|
t.Errorf("string primary key should not be populated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestExceptionsWithInvalidSql(t *testing.T) {
|
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||||
var columns []string
|
var columns []string
|
||||||
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||||
|
@ -13,11 +13,19 @@ import (
|
|||||||
|
|
||||||
var modelStructs = map[reflect.Type]*ModelStruct{}
|
var modelStructs = map[reflect.Type]*ModelStruct{}
|
||||||
|
|
||||||
|
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||||
|
return defaultTableName
|
||||||
|
}
|
||||||
|
|
||||||
type ModelStruct struct {
|
type ModelStruct struct {
|
||||||
PrimaryFields []*StructField
|
PrimaryFields []*StructField
|
||||||
StructFields []*StructField
|
StructFields []*StructField
|
||||||
ModelType reflect.Type
|
ModelType reflect.Type
|
||||||
TableName func(*DB) string
|
defaultTableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ModelStruct) TableName(db *DB) string {
|
||||||
|
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StructField struct {
|
type StructField struct {
|
||||||
@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set tablename
|
// Set tablename
|
||||||
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
|
type tabler interface {
|
||||||
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
|
TableName() string
|
||||||
if name, ok := results[0].Interface().(string); ok {
|
}
|
||||||
modelStruct.TableName = func(*DB) string {
|
|
||||||
return name
|
if tabler, ok := reflect.New(scopeType).Interface().(interface {
|
||||||
}
|
TableName() string
|
||||||
}
|
}); ok {
|
||||||
}
|
modelStruct.defaultTableName = tabler.TableName()
|
||||||
} else {
|
} else {
|
||||||
name := ToDBName(scopeType.Name())
|
name := ToDBName(scopeType.Name())
|
||||||
if scope.db == nil || !scope.db.parent.singularTable {
|
if scope.db == nil || !scope.db.parent.singularTable {
|
||||||
@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStruct.TableName = func(*DB) string {
|
modelStruct.defaultTableName = name
|
||||||
return name
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all fields
|
// Get all fields
|
||||||
|
14
mysql.go
14
mysql.go
@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||||
if autoIncrease {
|
if autoIncrease {
|
||||||
return "int AUTO_INCREMENT"
|
return "int AUTO_INCREMENT"
|
||||||
}
|
}
|
||||||
return "int"
|
return "int"
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if autoIncrease {
|
||||||
|
return "int unsigned AUTO_INCREMENT"
|
||||||
|
}
|
||||||
|
return "int unsigned"
|
||||||
|
case reflect.Int64:
|
||||||
if autoIncrease {
|
if autoIncrease {
|
||||||
return "bigint AUTO_INCREMENT"
|
return "bigint AUTO_INCREMENT"
|
||||||
}
|
}
|
||||||
return "bigint"
|
return "bigint"
|
||||||
|
case reflect.Uint64:
|
||||||
|
if autoIncrease {
|
||||||
|
return "bigint unsigned AUTO_INCREMENT"
|
||||||
|
}
|
||||||
|
return "bigint unsigned"
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "double"
|
return "double"
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
|
@ -216,13 +216,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope {
|
|||||||
}
|
}
|
||||||
fieldStruct, _ := modelType.FieldByName(column)
|
fieldStruct, _ := modelType.FieldByName(column)
|
||||||
var columns reflect.Value
|
var columns reflect.Value
|
||||||
if fieldStruct.Type.Kind() == reflect.Slice {
|
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
|
||||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
|
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
|
||||||
} else {
|
} else {
|
||||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
|
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
|
||||||
}
|
}
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
||||||
|
if column.Kind() == reflect.Ptr {
|
||||||
|
column = column.Elem()
|
||||||
|
}
|
||||||
if column.Kind() == reflect.Slice {
|
if column.Kind() == reflect.Slice {
|
||||||
for i := 0; i < column.Len(); i++ {
|
for i := 0; i < column.Len(); i++ {
|
||||||
columns = reflect.Append(columns, column.Index(i).Addr())
|
columns = reflect.Append(columns, column.Index(i).Addr())
|
||||||
|
9
scope.go
9
scope.go
@ -251,12 +251,7 @@ func (scope *Scope) TableName() string {
|
|||||||
return tabler.TableName(scope.db)
|
return tabler.TableName(scope.db)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.GetModelStruct().TableName != nil {
|
return scope.GetModelStruct().TableName(scope.db)
|
||||||
return scope.GetModelStruct().TableName(scope.db)
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Err(errors.New("wrong table name"))
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) QuotedTableName() (name string) {
|
func (scope *Scope) QuotedTableName() (name string) {
|
||||||
@ -278,7 +273,7 @@ func (scope *Scope) CombinedConditionSql() string {
|
|||||||
|
|
||||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if field.Name == name {
|
if field.Name == name || field.DBName == name {
|
||||||
return field, true
|
return field, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -480,7 +480,7 @@ func (scope *Scope) createTable() *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if field.IsPrimaryKey {
|
if field.IsPrimaryKey {
|
||||||
primaryKeys = append(primaryKeys, field.DBName)
|
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
||||||
}
|
}
|
||||||
scope.createJoinTable(field)
|
scope.createJoinTable(field)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user