Work on create callbacks
This commit is contained in:
		
							parent
							
								
									728c0d4470
								
							
						
					
					
						commit
						d52ee0aa44
					
				| @ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| ) | ||||
| 
 | ||||
| func BeforeCreate(db *gorm.DB) { | ||||
| @ -17,8 +18,14 @@ func SaveBeforeAssociations(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Create(db *gorm.DB) { | ||||
| 	db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") | ||||
| 	db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Insert{ | ||||
| 		Table: clause.Table{Table: db.Statement.Table}, | ||||
| 	}) | ||||
| 
 | ||||
| 	db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING") | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	fmt.Println(err) | ||||
| 	fmt.Println(result) | ||||
| 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -55,7 +55,9 @@ func (db *DB) Omit(columns ...string) (tx *DB) { | ||||
| 
 | ||||
| func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)}) | ||||
| 	tx.Statement.AddClause(clause.Where{ | ||||
| 		AndConditions: tx.Statement.BuildCondtion(query, args...), | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -63,7 +65,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { | ||||
| func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.AddClause(clause.Where{ | ||||
| 		AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))}, | ||||
| 		AndConditions: []clause.Expression{ | ||||
| 			clause.NotConditions(tx.Statement.BuildCondtion(query, args...)), | ||||
| 		}, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @ -72,7 +76,9 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { | ||||
| func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.AddClause(clause.Where{ | ||||
| 		ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)}, | ||||
| 		ORConditions: []clause.ORConditions{ | ||||
| 			tx.Statement.BuildCondtion(query, args...), | ||||
| 		}, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
							
								
								
									
										34
									
								
								clause/insert.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								clause/insert.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,34 @@ | ||||
| package clause | ||||
| 
 | ||||
| type Insert struct { | ||||
| 	Table    Table | ||||
| 	Priority string | ||||
| } | ||||
| 
 | ||||
| // Name insert clause name
 | ||||
| func (insert Insert) Name() string { | ||||
| 	return "INSERT" | ||||
| } | ||||
| 
 | ||||
| // Build build insert clause
 | ||||
| func (insert Insert) Build(builder Builder) { | ||||
| 	if insert.Priority != "" { | ||||
| 		builder.Write(insert.Priority) | ||||
| 		builder.WriteByte(' ') | ||||
| 	} | ||||
| 
 | ||||
| 	builder.Write("INTO ") | ||||
| 	builder.WriteQuoted(insert.Table) | ||||
| } | ||||
| 
 | ||||
| // MergeExpression merge insert clauses
 | ||||
| func (insert Insert) MergeExpression(expr Expression) { | ||||
| 	if v, ok := expr.(Insert); ok { | ||||
| 		if insert.Priority == "" { | ||||
| 			insert.Priority = v.Priority | ||||
| 		} | ||||
| 		if insert.Table.Table == "" { | ||||
| 			insert.Table = v.Table | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										39
									
								
								clause/value.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								clause/value.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,39 @@ | ||||
| package clause | ||||
| 
 | ||||
| type Values struct { | ||||
| 	Columns []Column | ||||
| 	Values  [][]interface{} | ||||
| } | ||||
| 
 | ||||
| // Name from clause name
 | ||||
| func (Values) Name() string { | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // Build build from clause
 | ||||
| func (values Values) Build(builder Builder) { | ||||
| 	if len(values.Columns) > 0 { | ||||
| 		builder.WriteByte('(') | ||||
| 		for idx, column := range values.Columns { | ||||
| 			if idx > 0 { | ||||
| 				builder.WriteByte(',') | ||||
| 			} | ||||
| 			builder.WriteQuoted(column) | ||||
| 		} | ||||
| 		builder.WriteByte(')') | ||||
| 
 | ||||
| 		builder.Write(" VALUES ") | ||||
| 
 | ||||
| 		for idx, value := range values.Values { | ||||
| 			builder.WriteByte('(') | ||||
| 			if idx > 0 { | ||||
| 				builder.WriteByte(',') | ||||
| 			} | ||||
| 
 | ||||
| 			builder.Write(builder.AddVar(value...)) | ||||
| 			builder.WriteByte(')') | ||||
| 		} | ||||
| 	} else { | ||||
| 		builder.Write("DEFAULT VALUES") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										33
									
								
								dialects/postgres/postgres.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								dialects/postgres/postgres.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | ||||
| package postgres | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/callbacks" | ||||
| 	_ "github.com/lib/pq" | ||||
| ) | ||||
| 
 | ||||
| type Dialector struct { | ||||
| 	DSN string | ||||
| } | ||||
| 
 | ||||
| func Open(dsn string) gorm.Dialector { | ||||
| 	return &Dialector{DSN: dsn} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db) | ||||
| 
 | ||||
| 	db.DB, err = sql.Open("postgres", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (Dialector) Migrator() gorm.Migrator { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||
| 	return "?" | ||||
| } | ||||
| @ -1,29 +1,33 @@ | ||||
| package sqlite | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/callbacks" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| ) | ||||
| 
 | ||||
| type Dialector struct { | ||||
| 	DSN string | ||||
| } | ||||
| 
 | ||||
| func Open(dsn string) gorm.Dialector { | ||||
| 	return &Dialector{} | ||||
| 	return &Dialector{DSN: dsn} | ||||
| } | ||||
| 
 | ||||
| func (Dialector) Initialize(db *gorm.DB) error { | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db) | ||||
| 
 | ||||
| 	return nil | ||||
| 	db.DB, err = sql.Open("sqlite3", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (Dialector) Migrator() gorm.Migrator { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { | ||||
| func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||
| 	return "?" | ||||
| } | ||||
|  | ||||
| @ -4,7 +4,16 @@ import ( | ||||
| 	"database/sql" | ||||
| ) | ||||
| 
 | ||||
| func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { | ||||
| // Create insert the value into database
 | ||||
| func (db *DB) Create(value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.Dest = value | ||||
| 	tx.callbacks.Create().Execute(tx) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Save update value in database, if the value doesn't have primary key, will insert it
 | ||||
| func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| @ -36,32 +45,12 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Row() *sql.Row { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Rows() (*sql.Rows, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| // Scan scan value to a struct
 | ||||
| func (db *DB) Scan(dest interface{}) (tx *DB) { | ||||
| func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Create insert the value into database
 | ||||
| func (db *DB) Create(value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Save update value in database, if the value doesn't have primary key, will insert it
 | ||||
| func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| @ -78,7 +67,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) { | ||||
| func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| @ -88,16 +77,6 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
 | ||||
| func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| @ -119,6 +98,29 @@ func (db *DB) Association(column string) *Association { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Count(value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Row() *sql.Row { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Rows() (*sql.Rows, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| // Scan scan value to a struct
 | ||||
| func (db *DB) Scan(dest interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { | ||||
| 	panicked := true | ||||
| 	tx := db.Begin(opts...) | ||||
|  | ||||
							
								
								
									
										6
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								go.mod
									
									
									
									
									
								
							| @ -2,4 +2,8 @@ module github.com/jinzhu/gorm | ||||
| 
 | ||||
| go 1.13 | ||||
| 
 | ||||
| require github.com/jinzhu/inflection v1.0.0 | ||||
| require ( | ||||
| 	github.com/jinzhu/inflection v1.0.0 | ||||
| 	github.com/lib/pq v1.3.0 | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible | ||||
| ) | ||||
|  | ||||
							
								
								
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							| @ -29,6 +29,7 @@ type DB struct { | ||||
| 	Dialector | ||||
| 	Instance | ||||
| 	DB             CommonDB | ||||
| 	ClauseBuilders map[string]clause.ClauseBuilder | ||||
| 	clone          bool | ||||
| 	callbacks      *callbacks | ||||
| 	cacheStore     *sync.Map | ||||
| @ -142,6 +143,7 @@ func (db *DB) getInstance() *DB { | ||||
| 			}, | ||||
| 			Config:         db.Config, | ||||
| 			Dialector:      db.Dialector, | ||||
| 			ClauseBuilders: db.ClauseBuilders, | ||||
| 			DB:             db.DB, | ||||
| 			callbacks:      db.callbacks, | ||||
| 			cacheStore:     db.cacheStore, | ||||
|  | ||||
| @ -9,7 +9,7 @@ import ( | ||||
| type Dialector interface { | ||||
| 	Initialize(*DB) error | ||||
| 	Migrator() Migrator | ||||
| 	BindVar(stmt Statement, v interface{}) string | ||||
| 	BindVar(stmt *Statement, v interface{}) string | ||||
| } | ||||
| 
 | ||||
| // CommonDB common db interface
 | ||||
|  | ||||
							
								
								
									
										47
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										47
									
								
								statement.go
									
									
									
									
									
								
							| @ -5,6 +5,7 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| @ -21,7 +22,7 @@ type Instance struct { | ||||
| 	Statement    *Statement | ||||
| } | ||||
| 
 | ||||
| func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { | ||||
| func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { | ||||
| 	if len(clauses) > 0 { | ||||
| 		instance.Statement.Build(clauses...) | ||||
| 	} | ||||
| @ -29,7 +30,7 @@ func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { | ||||
| } | ||||
| 
 | ||||
| // AddError add error to instance
 | ||||
| func (inst Instance) AddError(err error) { | ||||
| func (inst *Instance) AddError(err error) { | ||||
| 	if inst.Error == nil { | ||||
| 		inst.Error = err | ||||
| 	} else { | ||||
| @ -55,11 +56,11 @@ type Statement struct { | ||||
| 
 | ||||
| // StatementOptimizer statement optimizer interface
 | ||||
| type StatementOptimizer interface { | ||||
| 	OptimizeStatement(Statement) | ||||
| 	OptimizeStatement(*Statement) | ||||
| } | ||||
| 
 | ||||
| // Write write string
 | ||||
| func (stmt Statement) Write(sql ...string) (err error) { | ||||
| func (stmt *Statement) Write(sql ...string) (err error) { | ||||
| 	for _, s := range sql { | ||||
| 		_, err = stmt.SQL.WriteString(s) | ||||
| 	} | ||||
| @ -67,12 +68,12 @@ func (stmt Statement) Write(sql ...string) (err error) { | ||||
| } | ||||
| 
 | ||||
| // Write write string
 | ||||
| func (stmt Statement) WriteByte(c byte) (err error) { | ||||
| func (stmt *Statement) WriteByte(c byte) (err error) { | ||||
| 	return stmt.SQL.WriteByte(c) | ||||
| } | ||||
| 
 | ||||
| // WriteQuoted write quoted field
 | ||||
| func (stmt Statement) WriteQuoted(field interface{}) (err error) { | ||||
| func (stmt *Statement) WriteQuoted(field interface{}) (err error) { | ||||
| 	_, err = stmt.SQL.WriteString(stmt.Quote(field)) | ||||
| 	return | ||||
| } | ||||
| @ -107,7 +108,7 @@ func (stmt Statement) Quote(field interface{}) string { | ||||
| } | ||||
| 
 | ||||
| // Write write string
 | ||||
| func (stmt Statement) AddVar(vars ...interface{}) string { | ||||
| func (stmt *Statement) AddVar(vars ...interface{}) string { | ||||
| 	var placeholders strings.Builder | ||||
| 	for idx, v := range vars { | ||||
| 		if idx > 0 { | ||||
| @ -134,7 +135,7 @@ func (stmt Statement) AddVar(vars ...interface{}) string { | ||||
| } | ||||
| 
 | ||||
| // AddClause add clause
 | ||||
| func (stmt Statement) AddClause(v clause.Interface) { | ||||
| func (stmt *Statement) AddClause(v clause.Interface) { | ||||
| 	if optimizer, ok := v.(StatementOptimizer); ok { | ||||
| 		optimizer.OptimizeStatement(stmt) | ||||
| 	} | ||||
| @ -154,6 +155,30 @@ func (stmt Statement) AddClause(v clause.Interface) { | ||||
| 	stmt.Clauses[v.Name()] = c | ||||
| } | ||||
| 
 | ||||
| // AddClauseIfNotExists add clause if not exists
 | ||||
| func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { | ||||
| 	if optimizer, ok := v.(StatementOptimizer); ok { | ||||
| 		optimizer.OptimizeStatement(stmt) | ||||
| 	} | ||||
| 
 | ||||
| 	log.Println(v.Name()) | ||||
| 	if c, ok := stmt.Clauses[v.Name()]; !ok { | ||||
| 		if namer, ok := v.(clause.OverrideNameInterface); ok { | ||||
| 			c.Name = namer.OverrideName() | ||||
| 		} else { | ||||
| 			c.Name = v.Name() | ||||
| 		} | ||||
| 
 | ||||
| 		if c.Expression != nil { | ||||
| 			v.MergeExpression(c.Expression) | ||||
| 		} | ||||
| 
 | ||||
| 		c.Expression = v | ||||
| 		stmt.Clauses[v.Name()] = c | ||||
| 		log.Println(stmt.Clauses[v.Name()]) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // BuildCondtion build condition
 | ||||
| func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { | ||||
| 	if sql, ok := query.(string); ok { | ||||
| @ -211,7 +236,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con | ||||
| } | ||||
| 
 | ||||
| // Build build sql with clauses names
 | ||||
| func (stmt Statement) Build(clauses ...string) { | ||||
| func (stmt *Statement) Build(clauses ...string) { | ||||
| 	var firstClauseWritten bool | ||||
| 
 | ||||
| 	for _, name := range clauses { | ||||
| @ -221,8 +246,12 @@ func (stmt Statement) Build(clauses ...string) { | ||||
| 			} | ||||
| 
 | ||||
| 			firstClauseWritten = true | ||||
| 			if b, ok := stmt.DB.ClauseBuilders[name]; ok { | ||||
| 				b.Build(c, stmt) | ||||
| 			} else { | ||||
| 				c.Build(stmt) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	// TODO handle named vars
 | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu