Add CopyIn for postgres dialect

This commit is contained in:
fwhez 2018-11-12 12:44:09 +08:00
parent 472c70caa4
commit 09318cf444
3 changed files with 146 additions and 4 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
documents documents
_book _book
.idea/workspace.xml

86
main.go
View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"github.com/lib/pq"
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
@ -16,6 +17,10 @@ type DB struct {
Error error Error error
RowsAffected int64 RowsAffected int64
// for copyIn
dataSource string
copyDB *sql.DB
// single db // single db
db SQLCommon db SQLCommon
blockGlobalUpdate bool blockGlobalUpdate bool
@ -74,6 +79,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
logger: defaultLogger, logger: defaultLogger,
callbacks: DefaultCallback, callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL), dialect: newDialect(dialect, dbSQL),
dataSource: source,
} }
db.parent = db db.parent = db
if err != nil { if err != nil {
@ -96,6 +102,21 @@ func (s *DB) New() *DB {
return clone return clone
} }
// DataSource returns its opened dataSource:
// gorm.Open("postgres", "host=localhost user=root dbname=test sslmode=disable password=123")
// Then dataSource = "host=localhost user=root dbname=test sslmode=disable password=123"
func (s *DB) DataSource() string {
return s.dataSource
}
// CopyDB returns its origin db engine whose driver is github.com/lib/pq
func (s *DB) CopyDB() (*sql.DB, error) {
if s.copyDB != nil {
return s.copyDB, nil
}
return sql.Open(s.Dialect().GetName(), s.DataSource())
}
type closer interface { type closer interface {
Close() error Close() error
} }
@ -462,6 +483,71 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB {
return scope.Exec().db return scope.Exec().db
} }
// CopyIn execute a big amount of insertions
// closeAfterUsed: After finished or stopped by accident, the copyDB close itself if 'closeAfterUsed' has been set true
// table: which table you want to operate
// args: what you want to insert
// columns: what column you want to operate
// For instance:
// args = append(args, []interface{}{
// "tom", 9,
// }, []interface{}{
// "sara", 10,
// }, []interface{}{
// "jim", 11,
// })
// db.CopyIn(false, "user", args, "name","age")
// This stands for 'insert into user(name,age) values('tom', 9),('sara',10),('jim',19)'
// or 'COPY 'user' ('name','age') FROM STDIN`'
func (s *DB) CopyIn(closeAfterUsed bool, table string, args [][]interface{}, columns ...string) error {
if s.Dialect().GetName() != "postgres" {
return errors.New("CopyIn only supports postgres")
}
pdb, err := s.CopyDB()
if closeAfterUsed {
defer pdb.Close()
}
if err != nil {
return err
}
txn, err := pdb.Begin()
if err != nil {
return err
}
stmt, err := txn.Prepare(pq.CopyIn(table, columns...))
if err != nil {
return err
}
for _, v := range args {
_, err = stmt.Exec(v...)
if err != nil {
txn.Rollback()
return err
}
}
_, err = stmt.Exec()
if err != nil {
txn.Rollback()
return err
}
err = stmt.Close()
if err != nil {
txn.Rollback()
return err
}
err = txn.Commit()
if err != nil {
txn.Rollback()
return err
}
return nil
}
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
// // update all users's name to `hello` // // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello") // db.Model(&User{}).Update("name", "hello")

View File

@ -1059,6 +1059,61 @@ func TestBlockGlobalUpdate(t *testing.T) {
} }
} }
func TestDB_DataSource(t *testing.T) {
source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
db, er := gorm.Open("postgres", source)
if er != nil {
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er))
}
if db.DataSource() != source {
t.Fatal(fmt.Sprintf("want '%s', but got '%s'", source, db.DataSource()))
}
}
func TestDB_CopyIn(t *testing.T) {
source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
db, er := gorm.Open("postgres", source)
if er != nil {
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er))
}
e := db.Exec("create table if not exists example(name varchar, age integer)").Error
defer func() {
er := db.Exec("drop table if exists example").Error
if er != nil {
t.Fatal(e.Error())
}
}()
if e != nil {
t.Fatal(e.Error())
}
defer db.Exec("drop table")
var args = make([][]interface{}, 0)
args = append(args, []interface{}{
"tom", 9,
}, []interface{}{
"sara", 10,
}, []interface{}{
"jim", 11,
})
e = db.CopyIn(true, "example", args, "name", "age")
if e != nil {
t.Fatal(e.Error())
}
type Example struct {
Name string
Age int
}
var examples = make([]Example, 0)
e = db.Raw("select * from example").Find(&examples).Error
if e != nil {
t.Fatal(e.Error())
}
if len(examples) == 0 {
t.Fatal("examples length wants more than 3, but got 0")
}
fmt.Println(examples)
}
func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
b.N = 2000 b.N = 2000
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {