diff --git a/callback_create.go b/callback_create.go index b21df08b..7f21ed6a 100644 --- a/callback_create.go +++ b/callback_create.go @@ -70,18 +70,24 @@ func Create(scope *Scope) { id, err := result.LastInsertId() if scope.Err(err) == nil { scope.db.RowsAffected, _ = result.RowsAffected() - if primaryField != nil { + if primaryField != nil && primaryField.IsBlank { scope.Err(scope.SetColumn(primaryField, id)) } } } } else { if primaryField == nil { - if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil { + if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { scope.db.RowsAffected, _ = results.RowsAffected() + } else { + scope.Err(err) + } + } else { + if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { + scope.db.RowsAffected = 1 + } else { + scope.Err(err) } - } else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { - scope.db.RowsAffected = 1 } } } diff --git a/doc/development.md b/doc/development.md index 674cfc43..08166661 100644 --- a/doc/development.md +++ b/doc/development.md @@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go) -[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go) +[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go) [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go) diff --git a/main.go b/main.go index bf8acbae..181722fd 100644 --- a/main.go +++ b/main.go @@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join } } } - -func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) { - s.NewScope(source).GetModelStruct().TableName = handler -} diff --git a/main_test.go b/main_test.go index b547534c..0dc5e337 100644 --- a/main_test.go +++ b/main_test.go @@ -61,6 +61,19 @@ func init() { runMigration() } +func TestStringPrimaryKey(t *testing.T) { + type UUIDStruct struct { + ID string `gorm:"primary_key"` + Name string + } + DB.AutoMigrate(&UUIDStruct{}) + + data := UUIDStruct{ID: "uuid", Name: "hello"} + if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { + t.Errorf("string primary key should not be populated") + } +} + func TestExceptionsWithInvalidSql(t *testing.T) { var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { diff --git a/model_struct.go b/model_struct.go index a70489fc..10423ae2 100644 --- a/model_struct.go +++ b/model_struct.go @@ -13,11 +13,19 @@ import ( var modelStructs = map[reflect.Type]*ModelStruct{} +var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { + return defaultTableName +} + type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - TableName func(*DB) string + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + defaultTableName string +} + +func (s ModelStruct) TableName(db *DB) string { + return DefaultTableNameHandler(db, s.defaultTableName) } type StructField struct { @@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Set tablename - if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { - if results := fm.Call([]reflect.Value{}); len(results) > 0 { - if name, ok := results[0].Interface().(string); ok { - modelStruct.TableName = func(*DB) string { - return name - } - } - } + type tabler interface { + TableName() string + } + + if tabler, ok := reflect.New(scopeType).Interface().(interface { + TableName() string + }); ok { + modelStruct.defaultTableName = tabler.TableName() } else { name := ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { @@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStruct.TableName = func(*DB) string { - return name - } + modelStruct.defaultTableName = name } // Get all fields diff --git a/mysql.go b/mysql.go index e37a23e0..a5e4a459 100644 --- a/mysql.go +++ b/mysql.go @@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: if autoIncrease { return "int AUTO_INCREMENT" } return "int" - case reflect.Int64, reflect.Uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "int unsigned AUTO_INCREMENT" + } + return "int unsigned" + case reflect.Int64: if autoIncrease { return "bigint AUTO_INCREMENT" } return "bigint" + case reflect.Uint64: + if autoIncrease { + return "bigint unsigned AUTO_INCREMENT" + } + return "bigint unsigned" case reflect.Float32, reflect.Float64: return "double" case reflect.String: diff --git a/preload.go b/preload.go index add077ab..f1c0fae5 100644 --- a/preload.go +++ b/preload.go @@ -216,13 +216,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope { } fieldStruct, _ := modelType.FieldByName(column) var columns reflect.Value - if fieldStruct.Type.Kind() == reflect.Slice { + if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr { columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() } else { columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() } for i := 0; i < values.Len(); i++ { column := reflect.Indirect(values.Index(i)).FieldByName(column) + if column.Kind() == reflect.Ptr { + column = column.Elem() + } if column.Kind() == reflect.Slice { for i := 0; i < column.Len(); i++ { columns = reflect.Append(columns, column.Index(i).Addr()) diff --git a/scope.go b/scope.go index 54bf5c84..11bad777 100644 --- a/scope.go +++ b/scope.go @@ -251,12 +251,7 @@ func (scope *Scope) TableName() string { return tabler.TableName(scope.db) } - if scope.GetModelStruct().TableName != nil { - return scope.GetModelStruct().TableName(scope.db) - } - - scope.Err(errors.New("wrong table name")) - return "" + return scope.GetModelStruct().TableName(scope.db) } func (scope *Scope) QuotedTableName() (name string) { @@ -278,7 +273,7 @@ func (scope *Scope) CombinedConditionSql() string { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { for _, field := range scope.Fields() { - if field.Name == name { + if field.Name == name || field.DBName == name { return field, true } } diff --git a/scope_private.go b/scope_private.go index 2f3797b7..d4d1343e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -480,7 +480,7 @@ func (scope *Scope) createTable() *Scope { } if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, field.DBName) + primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) } scope.createJoinTable(field) }