+ add CopyIn() method for gorm.DB
This commit is contained in:
parent
472c70caa4
commit
16b210d1bd
81
main.go
81
main.go
@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/lib/pq"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -16,6 +17,10 @@ type DB struct {
|
||||
Error error
|
||||
RowsAffected int64
|
||||
|
||||
// for copyIn
|
||||
dataSource string
|
||||
copyDB *sql.DB
|
||||
|
||||
// single db
|
||||
db SQLCommon
|
||||
blockGlobalUpdate bool
|
||||
@ -74,6 +79,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
||||
logger: defaultLogger,
|
||||
callbacks: DefaultCallback,
|
||||
dialect: newDialect(dialect, dbSQL),
|
||||
dataSource: source,
|
||||
}
|
||||
db.parent = db
|
||||
if err != nil {
|
||||
@ -96,6 +102,23 @@ func (s *DB) New() *DB {
|
||||
return clone
|
||||
}
|
||||
|
||||
// DataSource returns its opened dataSource:
|
||||
// gorm.Open("postgres", "host=localhost user=root dbname=test sslmode=disable password=123")
|
||||
// Then dataSource = "host=localhost user=root dbname=test sslmode=disable password=123"
|
||||
func (s *DB) DataSource() string{
|
||||
return s.dataSource
|
||||
}
|
||||
|
||||
// CopyDB returns its origin db engine whose driver is github.com/lib/pq
|
||||
func (s *DB) CopyDB() (*sql.DB,error){
|
||||
if s.copyDB!=nil {
|
||||
return s.copyDB,nil
|
||||
}else{
|
||||
return sql.Open(s.Dialect().GetName(), s.DataSource())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
type closer interface {
|
||||
Close() error
|
||||
}
|
||||
@ -462,6 +485,64 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
||||
return scope.Exec().db
|
||||
}
|
||||
|
||||
// CopyIn execute a big amount of insertions
|
||||
// closeAfterUsed: After finished or stopped by accident, the copyDB close itself if 'closeAfterUsed' has been set true
|
||||
// table: which table you want to operate
|
||||
// args: what you want to insert
|
||||
// columns: what column you want to operate
|
||||
// For instance:
|
||||
// db.CopyIn(false, "user", [][]string{["tom", 9],["sara", 10],["jim", 19]}, "name","age")
|
||||
// This stands for 'insert into user(name,age) values('tom', 9),('sara',10),('jim',19)'
|
||||
// or 'COPY 'user' ('name','age') FROM STDIN`'
|
||||
func (s *DB) CopyIn(closeAfterUsed bool,table string, args [][]interface{}, columns ...string) error {
|
||||
if s.Dialect().GetName() != "postgres" {
|
||||
return errors.New("CopyIn only supports postgres")
|
||||
}
|
||||
|
||||
pdb, err := s.CopyDB()
|
||||
if closeAfterUsed{
|
||||
defer pdb.Close()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txn, err := pdb.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(pq.CopyIn(table, columns...))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _,v:= range args {
|
||||
_, err = stmt.Exec(v...)
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
return err
|
||||
}
|
||||
err = txn.Commit()
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
|
49
main_copyIn_test.go
Normal file
49
main_copyIn_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/fwhezfwhez/gorm"
|
||||
"testing"
|
||||
|
||||
)
|
||||
|
||||
func TestDB_DataSource(t *testing.T) {
|
||||
source := fmt.Sprintf("host=%s user=%s dbname=%s sslmode=%s password=%s",
|
||||
"localhost", "postgres", "test", "disable", "123")
|
||||
db, er := Open("postgres", source)
|
||||
if er != nil {
|
||||
t.Fatal(er.Error())
|
||||
}
|
||||
fmt.Println(db.DataSource())
|
||||
}
|
||||
func TestDB_CopyIn(t *testing.T) {
|
||||
source := fmt.Sprintf("host=%s user=%s dbname=%s sslmode=%s password=%s",
|
||||
"localhost", "postgres", "test", "disable", "123")
|
||||
db, er := gorm.Open("postgres", source)
|
||||
if er != nil {
|
||||
t.Fatal(er.Error())
|
||||
}
|
||||
db.Exec("create table if not exists example(name varchar, age integer)")
|
||||
var args = make([][]interface{}, 0)
|
||||
args = append(args, []interface{}{
|
||||
"tom", 9,
|
||||
}, []interface{}{
|
||||
"sara", 10,
|
||||
}, []interface{}{
|
||||
"jim", 11,
|
||||
})
|
||||
e := db.CopyIn(true, "example", args, "name", "age")
|
||||
if e != nil {
|
||||
t.Fatal(e.Error())
|
||||
}
|
||||
type Example struct{
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
var examples = make([]Example,0)
|
||||
e=db.Raw("select * from example").Find(&examples).Error
|
||||
if e!=nil {
|
||||
t.Fatal(e.Error())
|
||||
}
|
||||
fmt.Println(examples)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user