Merge pull request #1 from go-gorm/master

更新作者最新代码
This commit is contained in:
qifengzhang007 2020-10-23 22:47:01 +08:00 committed by GitHub
commit f982133c16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 45 additions and 18 deletions

View File

@ -326,6 +326,15 @@ func (db *DB) Count(count *int64) (tx *DB) {
defer tx.Statement.AddClause(clause.Select{}) defer tx.Statement.AddClause(clause.Select{})
} }
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(db.Statement.Clauses, "ORDER BY")
defer func() {
db.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx) tx.callbacks.Query().Execute(tx)
if tx.RowsAffected != 1 { if tx.RowsAffected != 1 {
@ -356,9 +365,13 @@ func (db *DB) Rows() (*sql.Rows, error) {
// Scan scan value to a struct // Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
currentLogger, newLogger := db.Logger, logger.Recorder.New() config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
tx = db.getInstance() tx = db.getInstance()
tx.Logger = newLogger tx.Config = &config
if rows, err := tx.Rows(); err != nil { if rows, err := tx.Rows(); err != nil {
tx.AddError(err) tx.AddError(err)
} else { } else {

View File

@ -32,6 +32,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
stmt := &gorm.Statement{DB: m.DB} stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil { if m.DB.Statement != nil {
stmt.Table = m.DB.Statement.Table stmt.Table = m.DB.Statement.Table
stmt.TableExpr = m.DB.Statement.TableExpr
} }
if table, ok := value.(string); ok { if table, ok := value.(string); ok {
@ -161,6 +162,10 @@ func (m Migrator) CreateTable(values ...interface{}) error {
hasPrimaryKeyInDataType bool hasPrimaryKeyInDataType bool
) )
if stmt.TableExpr != nil {
values[0] = *stmt.TableExpr
}
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName] field := stmt.Schema.FieldsByDBName[dbName]
createTableSQL += "? ?" createTableSQL += "? ?"
@ -370,9 +375,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true alterColumn = true
} else { } else {
// has size in data type and not equal // has size in data type and not equal
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1) matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1)
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1)
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
alterColumn = true alterColumn = true
} }
} }

View File

@ -762,13 +762,15 @@ func (field *Field) setupValuerAndSetter() {
// pointer scanner // pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().AssignableTo(field.FieldType) { if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() { if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
err = field.Set(value, reflectV.Elem().Interface()) return field.Set(value, reflectV.Elem().Interface())
} }
} else { } else {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(value)

View File

@ -70,6 +70,11 @@ func TestCount(t *testing.T) {
var count4 int64 var count4 int64
if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
t.Errorf("count with join, got error: %v, count %v", err, count4)
}
var count5 int64
if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 {
t.Errorf("count with join, got error: %v, count %v", err, count) t.Errorf("count with join, got error: %v, count %v", err, count)
} }
} }

View File

@ -6,11 +6,11 @@ require (
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
gorm.io/driver/mysql v1.0.2 gorm.io/driver/mysql v1.0.3
gorm.io/driver/postgres v1.0.4 gorm.io/driver/postgres v1.0.5
gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.5 gorm.io/driver/sqlserver v1.0.5
gorm.io/gorm v1.20.2 gorm.io/gorm v1.20.4
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -48,11 +48,13 @@ func TestMigrate(t *testing.T) {
} }
func TestSmartMigrateColumn(t *testing.T) { func TestSmartMigrateColumn(t *testing.T) {
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
type UserMigrateColumn struct { type UserMigrateColumn struct {
ID uint ID uint
Name string Name string
Salary float64 Salary float64
Birthday time.Time Birthday time.Time `gorm:"precision:4"`
} }
DB.Migrator().DropTable(&UserMigrateColumn{}) DB.Migrator().DropTable(&UserMigrateColumn{})
@ -78,15 +80,15 @@ func TestSmartMigrateColumn(t *testing.T) {
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
switch columnType.Name() { switch columnType.Name() {
case "name": case "name":
if length, _ := columnType.Length(); length != 0 && length != 128 { if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
t.Fatalf("name's length should be 128, but got %v", length) t.Fatalf("name's length should be 128, but got %v", length)
} }
case "salary": case "salary":
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
} }
case "birthday": case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("birthday's precision should be 2, but got %v", precision) t.Fatalf("birthday's precision should be 2, but got %v", precision)
} }
} }
@ -111,15 +113,15 @@ func TestSmartMigrateColumn(t *testing.T) {
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
switch columnType.Name() { switch columnType.Name() {
case "name": case "name":
if length, _ := columnType.Length(); length != 0 && length != 256 { if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
t.Fatalf("name's length should be 128, but got %v", length) t.Fatalf("name's length should be 128, but got %v", length)
} }
case "salary": case "salary":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("salary's precision should be 2, but got %v", precision) t.Fatalf("salary's precision should be 2, but got %v", precision)
} }
case "birthday": case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("birthday's precision should be 2, but got %v", precision) t.Fatalf("birthday's precision should be 2, but got %v", precision)
} }
} }