Merge pull request #1 from go-gorm/master

更新作者最新代码
This commit is contained in:
qifengzhang007 2020-10-23 22:47:01 +08:00 committed by GitHub
commit f982133c16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 45 additions and 18 deletions

View File

@ -326,6 +326,15 @@ func (db *DB) Count(count *int64) (tx *DB) {
defer tx.Statement.AddClause(clause.Select{})
}
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(db.Statement.Clauses, "ORDER BY")
defer func() {
db.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx)
if tx.RowsAffected != 1 {
@ -356,9 +365,13 @@ func (db *DB) Rows() (*sql.Rows, error) {
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
currentLogger, newLogger := db.Logger, logger.Recorder.New()
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
tx = db.getInstance()
tx.Logger = newLogger
tx.Config = &config
if rows, err := tx.Rows(); err != nil {
tx.AddError(err)
} else {

View File

@ -32,6 +32,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil {
stmt.Table = m.DB.Statement.Table
stmt.TableExpr = m.DB.Statement.TableExpr
}
if table, ok := value.(string); ok {
@ -161,6 +162,10 @@ func (m Migrator) CreateTable(values ...interface{}) error {
hasPrimaryKeyInDataType bool
)
if stmt.TableExpr != nil {
values[0] = *stmt.TableExpr
}
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
createTableSQL += "? ?"
@ -370,9 +375,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true
} else {
// has size in data type and not equal
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1)
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1)
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1)
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
alterColumn = true
}
}

View File

@ -762,13 +762,15 @@ func (field *Field) setupValuerAndSetter() {
// pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().AssignableTo(field.FieldType) {
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
err = field.Set(value, reflectV.Elem().Interface())
return field.Set(value, reflectV.Elem().Interface())
}
} else {
fieldValue := field.ReflectValueOf(value)

View File

@ -70,6 +70,11 @@ func TestCount(t *testing.T) {
var count4 int64
if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
t.Errorf("count with join, got error: %v, count %v", err, count4)
}
var count5 int64
if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 {
t.Errorf("count with join, got error: %v, count %v", err, count)
}
}

View File

@ -6,11 +6,11 @@ require (
github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
gorm.io/driver/mysql v1.0.2
gorm.io/driver/postgres v1.0.4
gorm.io/driver/mysql v1.0.3
gorm.io/driver/postgres v1.0.5
gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.5
gorm.io/gorm v1.20.2
gorm.io/gorm v1.20.4
)
replace gorm.io/gorm => ../

View File

@ -48,11 +48,13 @@ func TestMigrate(t *testing.T) {
}
func TestSmartMigrateColumn(t *testing.T) {
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
type UserMigrateColumn struct {
ID uint
Name string
Salary float64
Birthday time.Time
Birthday time.Time `gorm:"precision:4"`
}
DB.Migrator().DropTable(&UserMigrateColumn{})
@ -78,15 +80,15 @@ func TestSmartMigrateColumn(t *testing.T) {
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); length != 0 && length != 128 {
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}
@ -111,15 +113,15 @@ func TestSmartMigrateColumn(t *testing.T) {
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); length != 0 && length != 256 {
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("salary's precision should be 2, but got %v", precision)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}