diff --git a/main.go b/main.go index 2b165c1f..65ed4fd5 100644 --- a/main.go +++ b/main.go @@ -113,9 +113,8 @@ func (s *DB) DataSource() string { func (s *DB) CopyDB() (*sql.DB, error) { if s.copyDB != nil { return s.copyDB, nil - } else { - return sql.Open(s.Dialect().GetName(), s.DataSource()) } + return sql.Open(s.Dialect().GetName(), s.DataSource()) } type closer interface { diff --git a/main_test.go b/main_test.go index dbcd8621..62e798b1 100644 --- a/main_test.go +++ b/main_test.go @@ -1061,15 +1061,24 @@ func TestBlockGlobalUpdate(t *testing.T) { func TestDB_DataSource(t *testing.T) { source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" - if DB.DataSource() != source { - t.Fatal(fmt.Sprintf("want '%s', but got '%s'", source, DB.DataSource())) + db,er := gorm.Open("postgres", source) + if er!=nil { + panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er)) + } + if db.DataSource() != source { + t.Fatal(fmt.Sprintf("want '%s', but got '%s'", source, db.DataSource())) } } func TestDB_CopyIn(t *testing.T) { - e:=DB.Exec("create table if not exists example(name varchar, age integer)").Error + source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + db,er := gorm.Open("postgres", source) + if er!=nil { + panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er)) + } + e:=db.Exec("create table if not exists example(name varchar, age integer)").Error defer func(){ - er := DB.Exec("drop table if exists example").Error + er := db.Exec("drop table if exists example").Error if er != nil { t.Fatal(e.Error()) } @@ -1077,7 +1086,7 @@ func TestDB_CopyIn(t *testing.T) { if e != nil { t.Fatal(e.Error()) } - defer DB.Exec("drop table") + defer db.Exec("drop table") var args = make([][]interface{}, 0) args = append(args, []interface{}{ "tom", 9, @@ -1086,7 +1095,7 @@ func TestDB_CopyIn(t *testing.T) { }, []interface{}{ "jim", 11, }) - e = DB.CopyIn(true, "example", args, "name", "age") + e = db.CopyIn(true, "example", args, "name", "age") if e != nil { t.Fatal(e.Error()) } @@ -1095,7 +1104,7 @@ func TestDB_CopyIn(t *testing.T) { Age int } var examples = make([]Example, 0) - e = DB.Raw("select * from example").Find(&examples).Error + e = db.Raw("select * from example").Find(&examples).Error if e != nil { t.Fatal(e.Error()) }