Add CopyIn for postgres dialect
This commit is contained in:
parent
472c70caa4
commit
09318cf444
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
documents
|
documents
|
||||||
_book
|
_book
|
||||||
|
.idea/workspace.xml
|
||||||
|
86
main.go
86
main.go
@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/lib/pq"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -16,6 +17,10 @@ type DB struct {
|
|||||||
Error error
|
Error error
|
||||||
RowsAffected int64
|
RowsAffected int64
|
||||||
|
|
||||||
|
// for copyIn
|
||||||
|
dataSource string
|
||||||
|
copyDB *sql.DB
|
||||||
|
|
||||||
// single db
|
// single db
|
||||||
db SQLCommon
|
db SQLCommon
|
||||||
blockGlobalUpdate bool
|
blockGlobalUpdate bool
|
||||||
@ -74,6 +79,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
callbacks: DefaultCallback,
|
callbacks: DefaultCallback,
|
||||||
dialect: newDialect(dialect, dbSQL),
|
dialect: newDialect(dialect, dbSQL),
|
||||||
|
dataSource: source,
|
||||||
}
|
}
|
||||||
db.parent = db
|
db.parent = db
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -96,6 +102,21 @@ func (s *DB) New() *DB {
|
|||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DataSource returns its opened dataSource:
|
||||||
|
// gorm.Open("postgres", "host=localhost user=root dbname=test sslmode=disable password=123")
|
||||||
|
// Then dataSource = "host=localhost user=root dbname=test sslmode=disable password=123"
|
||||||
|
func (s *DB) DataSource() string {
|
||||||
|
return s.dataSource
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyDB returns its origin db engine whose driver is github.com/lib/pq
|
||||||
|
func (s *DB) CopyDB() (*sql.DB, error) {
|
||||||
|
if s.copyDB != nil {
|
||||||
|
return s.copyDB, nil
|
||||||
|
}
|
||||||
|
return sql.Open(s.Dialect().GetName(), s.DataSource())
|
||||||
|
}
|
||||||
|
|
||||||
type closer interface {
|
type closer interface {
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
@ -462,6 +483,71 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
|||||||
return scope.Exec().db
|
return scope.Exec().db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopyIn execute a big amount of insertions
|
||||||
|
// closeAfterUsed: After finished or stopped by accident, the copyDB close itself if 'closeAfterUsed' has been set true
|
||||||
|
// table: which table you want to operate
|
||||||
|
// args: what you want to insert
|
||||||
|
// columns: what column you want to operate
|
||||||
|
// For instance:
|
||||||
|
// args = append(args, []interface{}{
|
||||||
|
// "tom", 9,
|
||||||
|
// }, []interface{}{
|
||||||
|
// "sara", 10,
|
||||||
|
// }, []interface{}{
|
||||||
|
// "jim", 11,
|
||||||
|
// })
|
||||||
|
// db.CopyIn(false, "user", args, "name","age")
|
||||||
|
// This stands for 'insert into user(name,age) values('tom', 9),('sara',10),('jim',19)'
|
||||||
|
// or 'COPY 'user' ('name','age') FROM STDIN`'
|
||||||
|
func (s *DB) CopyIn(closeAfterUsed bool, table string, args [][]interface{}, columns ...string) error {
|
||||||
|
if s.Dialect().GetName() != "postgres" {
|
||||||
|
return errors.New("CopyIn only supports postgres")
|
||||||
|
}
|
||||||
|
|
||||||
|
pdb, err := s.CopyDB()
|
||||||
|
if closeAfterUsed {
|
||||||
|
defer pdb.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
txn, err := pdb.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := txn.Prepare(pq.CopyIn(table, columns...))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, v := range args {
|
||||||
|
_, err = stmt.Exec(v...)
|
||||||
|
if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stmt.Exec()
|
||||||
|
if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = stmt.Close()
|
||||||
|
if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = txn.Commit()
|
||||||
|
if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Model specify the model you would like to run db operations
|
// Model specify the model you would like to run db operations
|
||||||
// // update all users's name to `hello`
|
// // update all users's name to `hello`
|
||||||
// db.Model(&User{}).Update("name", "hello")
|
// db.Model(&User{}).Update("name", "hello")
|
||||||
|
55
main_test.go
55
main_test.go
@ -1059,6 +1059,61 @@ func TestBlockGlobalUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDB_DataSource(t *testing.T) {
|
||||||
|
source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
|
||||||
|
db, er := gorm.Open("postgres", source)
|
||||||
|
if er != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er))
|
||||||
|
}
|
||||||
|
if db.DataSource() != source {
|
||||||
|
t.Fatal(fmt.Sprintf("want '%s', but got '%s'", source, db.DataSource()))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
func TestDB_CopyIn(t *testing.T) {
|
||||||
|
source := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
|
||||||
|
db, er := gorm.Open("postgres", source)
|
||||||
|
if er != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", er))
|
||||||
|
}
|
||||||
|
e := db.Exec("create table if not exists example(name varchar, age integer)").Error
|
||||||
|
defer func() {
|
||||||
|
er := db.Exec("drop table if exists example").Error
|
||||||
|
if er != nil {
|
||||||
|
t.Fatal(e.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if e != nil {
|
||||||
|
t.Fatal(e.Error())
|
||||||
|
}
|
||||||
|
defer db.Exec("drop table")
|
||||||
|
var args = make([][]interface{}, 0)
|
||||||
|
args = append(args, []interface{}{
|
||||||
|
"tom", 9,
|
||||||
|
}, []interface{}{
|
||||||
|
"sara", 10,
|
||||||
|
}, []interface{}{
|
||||||
|
"jim", 11,
|
||||||
|
})
|
||||||
|
e = db.CopyIn(true, "example", args, "name", "age")
|
||||||
|
if e != nil {
|
||||||
|
t.Fatal(e.Error())
|
||||||
|
}
|
||||||
|
type Example struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
var examples = make([]Example, 0)
|
||||||
|
e = db.Raw("select * from example").Find(&examples).Error
|
||||||
|
if e != nil {
|
||||||
|
t.Fatal(e.Error())
|
||||||
|
}
|
||||||
|
if len(examples) == 0 {
|
||||||
|
t.Fatal("examples length wants more than 3, but got 0")
|
||||||
|
}
|
||||||
|
fmt.Println(examples)
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 2000
|
b.N = 2000
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user