Merge b38f14bfbd21d77faea7ac9b052faed8898d93ee into 0e22837c043e3f3405020fbb11ed4115accb5aaf
This commit is contained in:
commit
ee36367590
@ -106,6 +106,10 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True")
|
||||
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
|
||||
|
||||
// You can also use an existing database connection handle
|
||||
// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
// db := gorm.Open("postgres", dbSql)
|
||||
|
||||
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
|
||||
db.DB()
|
||||
|
||||
|
45
main.go
45
main.go
@ -35,28 +35,41 @@ type DB struct {
|
||||
values map[string]interface{}
|
||||
}
|
||||
|
||||
func Open(dialect string, drivesources ...string) (DB, error) {
|
||||
func Open(dialect string, args ...interface{}) (DB, error) {
|
||||
var db DB
|
||||
var err error
|
||||
var driver = dialect
|
||||
var source string
|
||||
var dbSql sqlCommon
|
||||
|
||||
if len(drivesources) == 0 {
|
||||
if len(args) == 0 {
|
||||
err = errors.New("invalid database source")
|
||||
} else {
|
||||
if len(drivesources) == 1 {
|
||||
source = drivesources[0]
|
||||
} else if len(drivesources) >= 2 {
|
||||
driver = drivesources[0]
|
||||
source = drivesources[1]
|
||||
}
|
||||
|
||||
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
|
||||
logger: defaultLogger, callback: DefaultCallback, source: source,
|
||||
values: map[string]interface{}{}}
|
||||
db.db, err = sql.Open(driver, source)
|
||||
db.parent = &db
|
||||
}
|
||||
|
||||
switch value := args[0].(type) {
|
||||
case string:
|
||||
var driver = dialect
|
||||
if len(args) == 1 {
|
||||
source = value
|
||||
} else if len(args) >= 2 {
|
||||
driver = value
|
||||
source = args[1].(string)
|
||||
}
|
||||
dbSql, err = sql.Open(driver, source)
|
||||
case sqlCommon:
|
||||
if reflect.TypeOf(value).Kind() == reflect.Ptr {
|
||||
source = reflect.ValueOf(value).Elem().FieldByName("dsn").String()
|
||||
} else {
|
||||
source = reflect.ValueOf(value).FieldByName("dsn").String()
|
||||
}
|
||||
dbSql = value
|
||||
}
|
||||
|
||||
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
|
||||
logger: defaultLogger, callback: DefaultCallback, source: source,
|
||||
values: map[string]interface{}{}}
|
||||
db.db = dbSql
|
||||
db.parent = &db
|
||||
|
||||
return db, err
|
||||
}
|
||||
|
||||
|
16
main_test.go
16
main_test.go
@ -554,6 +554,22 @@ func TestCompatibilityMode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenExistingDB(t *testing.T) {
|
||||
DB.Save(&User{Name: "jnfeinstein"})
|
||||
dialect := os.Getenv("GORM_DIALECT")
|
||||
|
||||
db, err := gorm.Open(dialect, DB.DB())
|
||||
if err != nil {
|
||||
t.Errorf("Should have wrapped the existing DB connection")
|
||||
}
|
||||
|
||||
var user User
|
||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
|
||||
t.Errorf("Should have found existing record")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
b.N = 2000
|
||||
for x := 0; x < b.N; x++ {
|
||||
|
Loading…
x
Reference in New Issue
Block a user