diff --git a/.gitignore b/.gitignore index 01dc5ce0..990073fd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ documents _book +.idea/workspace.xml diff --git a/main.go b/main.go index 17c75ed3..f4aa5573 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/lib/pq" "reflect" "strings" "sync" @@ -16,6 +17,10 @@ type DB struct { Error error RowsAffected int64 + // for copyIn + dataSource string + copyDB *sql.DB + // single db db SQLCommon blockGlobalUpdate bool @@ -70,10 +75,11 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } db = &DB{ - db: dbSQL, - logger: defaultLogger, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), + db: dbSQL, + logger: defaultLogger, + callbacks: DefaultCallback, + dialect: newDialect(dialect, dbSQL), + dataSource: source, } db.parent = db if err != nil { @@ -96,6 +102,21 @@ func (s *DB) New() *DB { return clone } +// DataSource returns its opened dataSource: +// gorm.Open("postgres", "host=localhost user=root dbname=test sslmode=disable password=123") +// Then dataSource = "host=localhost user=root dbname=test sslmode=disable password=123" +func (s *DB) DataSource() string { + return s.dataSource +} + +// CopyDB returns its origin db engine whose driver is github.com/lib/pq +func (s *DB) CopyDB() (*sql.DB, error) { + if s.copyDB != nil { + return s.copyDB, nil + } + return sql.Open(s.Dialect().GetName(), s.DataSource()) +} + type closer interface { Close() error } @@ -462,6 +483,71 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB { return scope.Exec().db } +// CopyIn execute a big amount of insertions +// closeAfterUsed: After finished or stopped by accident, the copyDB close itself if 'closeAfterUsed' has been set true +// table: which table you want to operate +// args: what you want to insert +// columns: what column you want to operate +// For instance: +// args = append(args, []interface{}{ +// "tom", 9, +// }, []interface{}{ +// "sara", 10, +// }, []interface{}{ +// "jim", 11, +// }) +// db.CopyIn(false, "user", args, "name","age") +// This stands for 'insert into user(name,age) values('tom', 9),('sara',10),('jim',19)' +// or 'COPY 'user' ('name','age') FROM STDIN`' +func (s *DB) CopyIn(closeAfterUsed bool, table string, args [][]interface{}, columns ...string) error { + if s.Dialect().GetName() != "postgres" { + return errors.New("CopyIn only supports postgres") + } + + pdb, err := s.CopyDB() + if closeAfterUsed { + defer pdb.Close() + } + + if err != nil { + return err + } + txn, err := pdb.Begin() + if err != nil { + return err + } + + stmt, err := txn.Prepare(pq.CopyIn(table, columns...)) + if err != nil { + return err + } + for _, v := range args { + _, err = stmt.Exec(v...) + if err != nil { + txn.Rollback() + return err + } + } + + _, err = stmt.Exec() + if err != nil { + txn.Rollback() + return err + } + + err = stmt.Close() + if err != nil { + txn.Rollback() + return err + } + err = txn.Commit() + if err != nil { + txn.Rollback() + return err + } + return nil +} + // Model specify the model you would like to run db operations // // update all users's name to `hello` // db.Model(&User{}).Update("name", "hello") diff --git a/main_test.go b/main_test.go index 94d2fa39..03eaf8e7 100644 --- a/main_test.go +++ b/main_test.go @@ -1059,6 +1059,61 @@ func TestBlockGlobalUpdate(t *testing.T) { } } +func TestDB_DataSource(t *testing.T) { + source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + db, er := gorm.Open("postgres", source) + if er != nil { + panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er)) + } + if db.DataSource() != source { + t.Fatal(fmt.Sprintf("want '%s', but got '%s'", source, db.DataSource())) + } + +} +func TestDB_CopyIn(t *testing.T) { + source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + db, er := gorm.Open("postgres", source) + if er != nil { + panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er)) + } + e := db.Exec("create table if not exists example(name varchar, age integer)").Error + defer func() { + er := db.Exec("drop table if exists example").Error + if er != nil { + t.Fatal(e.Error()) + } + }() + if e != nil { + t.Fatal(e.Error()) + } + defer db.Exec("drop table") + var args = make([][]interface{}, 0) + args = append(args, []interface{}{ + "tom", 9, + }, []interface{}{ + "sara", 10, + }, []interface{}{ + "jim", 11, + }) + e = db.CopyIn(true, "example", args, "name", "age") + if e != nil { + t.Fatal(e.Error()) + } + type Example struct { + Name string + Age int + } + var examples = make([]Example, 0) + e = db.Raw("select * from example").Find(&examples).Error + if e != nil { + t.Fatal(e.Error()) + } + if len(examples) == 0 { + t.Fatal("examples length wants more than 3, but got 0") + } + fmt.Println(examples) +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ {