diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b918b4b5..8b4a0722 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "fmt" "math/rand" "reflect" @@ -9,6 +10,7 @@ import ( "time" "gorm.io/driver/postgres" + "gorm.io/gorm" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" @@ -72,6 +74,43 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } + +} + +func TestAutoMigrateInt8PG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type MigrateInt struct { + Int8 int8 + } + + tracer := Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + } + }, + } + + DB.Migrator().DropTable(&MigrateInt{}) + + // The first AutoMigrate to make table with field with correct type + if err := DB.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + + // make new session to set custom logger tracer + session := DB.Session(&gorm.Session{Logger: tracer}) + + // The second AutoMigrate to catch an error + if err := session.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + } func TestAutoMigrateSelfReferential(t *testing.T) { @@ -400,7 +439,7 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) @@ -830,11 +869,11 @@ func TestUniqueColumn(t *testing.T) { value, ok = ct.DefaultValue() AssertEqual(t, "", value) AssertEqual(t, false, ok) + } func findColumnType(dest interface{}, columnName string) ( - foundColumn gorm.ColumnType, err error, -) { + foundColumn gorm.ColumnType, err error) { columnTypes, err := DB.Migrator().ColumnTypes(dest) if err != nil { err = fmt.Errorf("ColumnTypes err:%v", err) @@ -884,116 +923,3 @@ func TestInvalidCachedPlan(t *testing.T) { t.Errorf("AutoMigrate err:%v", err) } } - -func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { - type DiffType struct { - ID uint - Name string `gorm:"type:varchar(20)"` - } - - type DiffType1 struct { - ID uint - Name string `gorm:"type:text"` - } - - var err error - DB.Migrator().DropTable(&DiffType{}) - - err = DB.AutoMigrate(&DiffType{}) - if err != nil { - t.Errorf("AutoMigrate err:%v", err) - } - - ct, err := findColumnType(&DiffType{}, "name") - if err != nil { - t.Errorf("findColumnType err:%v", err) - } - - AssertEqual(t, "varchar", strings.ToLower(ct.DatabaseTypeName())) - - err = DB.Table("diff_types").AutoMigrate(&DiffType1{}) - if err != nil { - t.Errorf("AutoMigrate err:%v", err) - } - - ct, err = findColumnType(&DiffType{}, "name") - if err != nil { - t.Errorf("findColumnType err:%v", err) - } - - AssertEqual(t, "text", strings.ToLower(ct.DatabaseTypeName())) -} - -func TestMigrateArrayTypeModel(t *testing.T) { - if DB.Dialector.Name() != "postgres" { - return - } - - type ArrayTypeModel struct { - ID uint - Number string `gorm:"type:varchar(51);NOT NULL"` - TextArray []string `gorm:"type:text[];NOT NULL"` - NestedTextArray [][]string `gorm:"type:text[][]"` - NestedIntArray [][]int64 `gorm:"type:integer[3][3]"` - } - - var err error - DB.Migrator().DropTable(&ArrayTypeModel{}) - - err = DB.AutoMigrate(&ArrayTypeModel{}) - AssertEqual(t, nil, err) - - ct, err := findColumnType(&ArrayTypeModel{}, "number") - AssertEqual(t, nil, err) - AssertEqual(t, "varchar", ct.DatabaseTypeName()) - - ct, err = findColumnType(&ArrayTypeModel{}, "text_array") - AssertEqual(t, nil, err) - AssertEqual(t, "text[]", ct.DatabaseTypeName()) - - ct, err = findColumnType(&ArrayTypeModel{}, "nested_text_array") - AssertEqual(t, nil, err) - AssertEqual(t, "text[]", ct.DatabaseTypeName()) - - ct, err = findColumnType(&ArrayTypeModel{}, "nested_int_array") - AssertEqual(t, nil, err) - AssertEqual(t, "integer[]", ct.DatabaseTypeName()) -} - -func TestMigrateSameEmbeddedFieldName(t *testing.T) { - type UserStat struct { - GroundDestroyCount int - } - - type GameUser struct { - gorm.Model - StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"` - } - - type UserStat1 struct { - GroundDestroyCount string - } - - type GroundRate struct { - GroundDestroyCount int - } - - type GameUser1 struct { - gorm.Model - StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"` - GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"` - } - - DB.Migrator().DropTable(&GameUser{}) - err := DB.AutoMigrate(&GameUser{}) - AssertEqual(t, nil, err) - - err = DB.Table("game_users").AutoMigrate(&GameUser1{}) - AssertEqual(t, nil, err) - - _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") - AssertEqual(t, nil, err) - - _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") - AssertEqual(t, nil, err) -} diff --git a/tests/tracer_test.go b/tests/tracer_test.go new file mode 100644 index 00000000..3e9a4052 --- /dev/null +++ b/tests/tracer_test.go @@ -0,0 +1,34 @@ +package tests_test + +import ( + "context" + "time" + + "gorm.io/gorm/logger" +) + +type Tracer struct { + Logger logger.Interface + Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { + return S.Logger.LogMode(level) +} + +func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { + S.Logger.Info(ctx, s, i...) +} + +func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { + S.Logger.Warn(ctx, s, i...) +} + +func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { + S.Logger.Error(ctx, s, i...) +} + +func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + S.Logger.Trace(ctx, begin, fc, err) + S.Test(ctx, begin, fc, err) +}