commit
f982133c16
@ -326,6 +326,15 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
||||
defer tx.Statement.AddClause(clause.Select{})
|
||||
}
|
||||
|
||||
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
||||
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
|
||||
delete(db.Statement.Clauses, "ORDER BY")
|
||||
defer func() {
|
||||
db.Statement.Clauses["ORDER BY"] = orderByClause
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Dest = count
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
if tx.RowsAffected != 1 {
|
||||
@ -356,9 +365,13 @@ func (db *DB) Rows() (*sql.Rows, error) {
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
currentLogger, newLogger := db.Logger, logger.Recorder.New()
|
||||
config := *db.Config
|
||||
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
||||
config.Logger = newLogger
|
||||
|
||||
tx = db.getInstance()
|
||||
tx.Logger = newLogger
|
||||
tx.Config = &config
|
||||
|
||||
if rows, err := tx.Rows(); err != nil {
|
||||
tx.AddError(err)
|
||||
} else {
|
||||
|
@ -32,6 +32,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if m.DB.Statement != nil {
|
||||
stmt.Table = m.DB.Statement.Table
|
||||
stmt.TableExpr = m.DB.Statement.TableExpr
|
||||
}
|
||||
|
||||
if table, ok := value.(string); ok {
|
||||
@ -161,6 +162,10 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
hasPrimaryKeyInDataType bool
|
||||
)
|
||||
|
||||
if stmt.TableExpr != nil {
|
||||
values[0] = *stmt.TableExpr
|
||||
}
|
||||
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
createTableSQL += "? ?"
|
||||
@ -370,9 +375,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1)
|
||||
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
||||
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
|
||||
matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1)
|
||||
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1)
|
||||
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
@ -762,13 +762,15 @@ func (field *Field) setupValuerAndSetter() {
|
||||
// pointer scanner
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() {
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
err = field.Set(value, reflectV.Elem().Interface())
|
||||
return field.Set(value, reflectV.Elem().Interface())
|
||||
}
|
||||
} else {
|
||||
fieldValue := field.ReflectValueOf(value)
|
||||
|
@ -70,6 +70,11 @@ func TestCount(t *testing.T) {
|
||||
|
||||
var count4 int64
|
||||
if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
|
||||
t.Errorf("count with join, got error: %v, count %v", err, count4)
|
||||
}
|
||||
|
||||
var count5 int64
|
||||
if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 {
|
||||
t.Errorf("count with join, got error: %v, count %v", err, count)
|
||||
}
|
||||
}
|
||||
|
@ -6,11 +6,11 @@ require (
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/jinzhu/now v1.1.1
|
||||
github.com/lib/pq v1.6.0
|
||||
gorm.io/driver/mysql v1.0.2
|
||||
gorm.io/driver/postgres v1.0.4
|
||||
gorm.io/driver/mysql v1.0.3
|
||||
gorm.io/driver/postgres v1.0.5
|
||||
gorm.io/driver/sqlite v1.1.3
|
||||
gorm.io/driver/sqlserver v1.0.5
|
||||
gorm.io/gorm v1.20.2
|
||||
gorm.io/gorm v1.20.4
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
@ -48,11 +48,13 @@ func TestMigrate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSmartMigrateColumn(t *testing.T) {
|
||||
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
|
||||
|
||||
type UserMigrateColumn struct {
|
||||
ID uint
|
||||
Name string
|
||||
Salary float64
|
||||
Birthday time.Time
|
||||
Birthday time.Time `gorm:"precision:4"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&UserMigrateColumn{})
|
||||
@ -78,15 +80,15 @@ func TestSmartMigrateColumn(t *testing.T) {
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); length != 0 && length != 128 {
|
||||
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||
if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
}
|
||||
@ -111,15 +113,15 @@ func TestSmartMigrateColumn(t *testing.T) {
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); length != 0 && length != 256 {
|
||||
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v", precision)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user