Merge b38f14bfbd21d77faea7ac9b052faed8898d93ee into 0e22837c043e3f3405020fbb11ed4115accb5aaf

This commit is contained in:
Joel Feinstein 2014-11-25 07:15:46 +00:00
commit ee36367590
3 changed files with 49 additions and 16 deletions

View File

@ -106,6 +106,10 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True")
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
// You can also use an existing database connection handle
// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db := gorm.Open("postgres", dbSql)
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
db.DB()

33
main.go
View File

@ -35,28 +35,41 @@ type DB struct {
values map[string]interface{}
}
func Open(dialect string, drivesources ...string) (DB, error) {
func Open(dialect string, args ...interface{}) (DB, error) {
var db DB
var err error
var driver = dialect
var source string
var dbSql sqlCommon
if len(drivesources) == 0 {
if len(args) == 0 {
err = errors.New("invalid database source")
}
switch value := args[0].(type) {
case string:
var driver = dialect
if len(args) == 1 {
source = value
} else if len(args) >= 2 {
driver = value
source = args[1].(string)
}
dbSql, err = sql.Open(driver, source)
case sqlCommon:
if reflect.TypeOf(value).Kind() == reflect.Ptr {
source = reflect.ValueOf(value).Elem().FieldByName("dsn").String()
} else {
if len(drivesources) == 1 {
source = drivesources[0]
} else if len(drivesources) >= 2 {
driver = drivesources[0]
source = drivesources[1]
source = reflect.ValueOf(value).FieldByName("dsn").String()
}
dbSql = value
}
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
logger: defaultLogger, callback: DefaultCallback, source: source,
values: map[string]interface{}{}}
db.db, err = sql.Open(driver, source)
db.db = dbSql
db.parent = &db
}
return db, err
}

View File

@ -554,6 +554,22 @@ func TestCompatibilityMode(t *testing.T) {
}
}
func TestOpenExistingDB(t *testing.T) {
DB.Save(&User{Name: "jnfeinstein"})
dialect := os.Getenv("GORM_DIALECT")
db, err := gorm.Open(dialect, DB.DB())
if err != nil {
t.Errorf("Should have wrapped the existing DB connection")
}
var user User
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
t.Errorf("Should have found existing record")
}
}
func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {