From 16b210d1bd60db0db4d22558780b4802834ee184 Mon Sep 17 00:00:00 2001 From: fwhez <1728565484@qq.com> Date: Fri, 9 Nov 2018 18:16:36 +0800 Subject: [PATCH] + add CopyIn() method for gorm.DB --- main.go | 81 +++++++++++++++++++++++++++++++++++++++++++++ main_copyIn_test.go | 49 +++++++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 main_copyIn_test.go diff --git a/main.go b/main.go index 17c75ed3..5a1aec6d 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/lib/pq" "reflect" "strings" "sync" @@ -16,6 +17,10 @@ type DB struct { Error error RowsAffected int64 + // for copyIn + dataSource string + copyDB *sql.DB + // single db db SQLCommon blockGlobalUpdate bool @@ -74,6 +79,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { logger: defaultLogger, callbacks: DefaultCallback, dialect: newDialect(dialect, dbSQL), + dataSource: source, } db.parent = db if err != nil { @@ -96,6 +102,23 @@ func (s *DB) New() *DB { return clone } +// DataSource returns its opened dataSource: +// gorm.Open("postgres", "host=localhost user=root dbname=test sslmode=disable password=123") +// Then dataSource = "host=localhost user=root dbname=test sslmode=disable password=123" +func (s *DB) DataSource() string{ + return s.dataSource +} + +// CopyDB returns its origin db engine whose driver is github.com/lib/pq +func (s *DB) CopyDB() (*sql.DB,error){ + if s.copyDB!=nil { + return s.copyDB,nil + }else{ + return sql.Open(s.Dialect().GetName(), s.DataSource()) + } +} + + type closer interface { Close() error } @@ -462,6 +485,64 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB { return scope.Exec().db } +// CopyIn execute a big amount of insertions +// closeAfterUsed: After finished or stopped by accident, the copyDB close itself if 'closeAfterUsed' has been set true +// table: which table you want to operate +// args: what you want to insert +// columns: what column you want to operate +// For instance: +// db.CopyIn(false, "user", [][]string{["tom", 9],["sara", 10],["jim", 19]}, "name","age") +// This stands for 'insert into user(name,age) values('tom', 9),('sara',10),('jim',19)' +// or 'COPY 'user' ('name','age') FROM STDIN`' +func (s *DB) CopyIn(closeAfterUsed bool,table string, args [][]interface{}, columns ...string) error { + if s.Dialect().GetName() != "postgres" { + return errors.New("CopyIn only supports postgres") + } + + pdb, err := s.CopyDB() + if closeAfterUsed{ + defer pdb.Close() + } + + if err != nil { + return err + } + txn, err := pdb.Begin() + if err != nil { + return err + } + + stmt, err := txn.Prepare(pq.CopyIn(table, columns...)) + if err != nil { + return err + } + for _,v:= range args { + _, err = stmt.Exec(v...) + if err != nil { + txn.Rollback() + return err + } + } + + _, err = stmt.Exec() + if err != nil { + txn.Rollback() + return err + } + + err = stmt.Close() + if err != nil { + txn.Rollback() + return err + } + err = txn.Commit() + if err != nil { + txn.Rollback() + return err + } + return nil +} + // Model specify the model you would like to run db operations // // update all users's name to `hello` // db.Model(&User{}).Update("name", "hello") diff --git a/main_copyIn_test.go b/main_copyIn_test.go new file mode 100644 index 00000000..372ae299 --- /dev/null +++ b/main_copyIn_test.go @@ -0,0 +1,49 @@ +package gorm + +import ( + "fmt" + "github.com/fwhezfwhez/gorm" + "testing" + +) + +func TestDB_DataSource(t *testing.T) { + source := fmt.Sprintf("host=%s user=%s dbname=%s sslmode=%s password=%s", + "localhost", "postgres", "test", "disable", "123") + db, er := Open("postgres", source) + if er != nil { + t.Fatal(er.Error()) + } + fmt.Println(db.DataSource()) +} +func TestDB_CopyIn(t *testing.T) { + source := fmt.Sprintf("host=%s user=%s dbname=%s sslmode=%s password=%s", + "localhost", "postgres", "test", "disable", "123") + db, er := gorm.Open("postgres", source) + if er != nil { + t.Fatal(er.Error()) + } + db.Exec("create table if not exists example(name varchar, age integer)") + var args = make([][]interface{}, 0) + args = append(args, []interface{}{ + "tom", 9, + }, []interface{}{ + "sara", 10, + }, []interface{}{ + "jim", 11, + }) + e := db.CopyIn(true, "example", args, "name", "age") + if e != nil { + t.Fatal(e.Error()) + } + type Example struct{ + Name string + Age int + } + var examples = make([]Example,0) + e=db.Raw("select * from example").Find(&examples).Error + if e!=nil { + t.Fatal(e.Error()) + } + fmt.Println(examples) +}