callbacks support sort with wildcard
This commit is contained in:
		
							parent
							
								
									f83b00d20d
								
							
						
					
					
						commit
						c11c939b95
					
				
							
								
								
									
										16
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								callbacks.go
									
									
									
									
									
								
							| @ -5,6 +5,7 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/logger" | ||||
| @ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | ||||
| 		names, sorted []string | ||||
| 		sortCallback  func(*callback) error | ||||
| 	) | ||||
| 	sort.Slice(cs, func(i, j int) bool { | ||||
| 		return cs[j].before == "*" || cs[j].after == "*" | ||||
| 	}) | ||||
| 
 | ||||
| 	for _, c := range cs { | ||||
| 		// show warning message the callback name already exists
 | ||||
| @ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | ||||
| 
 | ||||
| 	sortCallback = func(c *callback) error { | ||||
| 		if c.before != "" { // if defined before callback
 | ||||
| 			if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { | ||||
| 			if c.before == "*" && len(sorted) > 0 { | ||||
| 				if curIdx := getRIndex(sorted, c.name); curIdx == -1 { | ||||
| 					sorted = append([]string{c.name}, sorted...) | ||||
| 				} | ||||
| 			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { | ||||
| 				if curIdx := getRIndex(sorted, c.name); curIdx == -1 { | ||||
| 					// if before callback already sorted, append current callback just after it
 | ||||
| 					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) | ||||
| @ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | ||||
| 		} | ||||
| 
 | ||||
| 		if c.after != "" { // if defined after callback
 | ||||
| 			if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { | ||||
| 			if c.after == "*" && len(sorted) > 0 { | ||||
| 				if curIdx := getRIndex(sorted, c.name); curIdx == -1 { | ||||
| 					sorted = append(sorted, c.name) | ||||
| 				} | ||||
| 			} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { | ||||
| 				if curIdx := getRIndex(sorted, c.name); curIdx == -1 { | ||||
| 					// if after callback sorted, append current callback to last
 | ||||
| 					sorted = append(sorted, c.name) | ||||
|  | ||||
							
								
								
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							| @ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB { | ||||
| 			preparedStmt := v.(*PreparedStmtDB) | ||||
| 			tx.Statement.ConnPool = &PreparedStmtDB{ | ||||
| 				ConnPool: db.Config.ConnPool, | ||||
| 				mux:      preparedStmt.mux, | ||||
| 				Mux:      preparedStmt.Mux, | ||||
| 				Stmts:    preparedStmt.Stmts, | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -9,12 +9,12 @@ import ( | ||||
| type PreparedStmtDB struct { | ||||
| 	Stmts       map[string]*sql.Stmt | ||||
| 	PreparedSQL []string | ||||
| 	mux         sync.RWMutex | ||||
| 	Mux         sync.RWMutex | ||||
| 	ConnPool | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) Close() { | ||||
| 	db.mux.Lock() | ||||
| 	db.Mux.Lock() | ||||
| 	for _, query := range db.PreparedSQL { | ||||
| 		if stmt, ok := db.Stmts[query]; ok { | ||||
| 			delete(db.Stmts, query) | ||||
| @ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	db.mux.Unlock() | ||||
| 	db.Mux.Unlock() | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { | ||||
| 	db.mux.RLock() | ||||
| 	db.Mux.RLock() | ||||
| 	if stmt, ok := db.Stmts[query]; ok { | ||||
| 		db.mux.RUnlock() | ||||
| 		db.Mux.RUnlock() | ||||
| 		return stmt, nil | ||||
| 	} | ||||
| 	db.mux.RUnlock() | ||||
| 	db.Mux.RUnlock() | ||||
| 
 | ||||
| 	db.mux.Lock() | ||||
| 	db.Mux.Lock() | ||||
| 	// double check
 | ||||
| 	if stmt, ok := db.Stmts[query]; ok { | ||||
| 		db.mux.Unlock() | ||||
| 		db.Mux.Unlock() | ||||
| 		return stmt, nil | ||||
| 	} | ||||
| 
 | ||||
| @ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { | ||||
| 		db.Stmts[query] = stmt | ||||
| 		db.PreparedSQL = append(db.PreparedSQL, query) | ||||
| 	} | ||||
| 	db.mux.Unlock() | ||||
| 	db.Mux.Unlock() | ||||
| 
 | ||||
| 	return stmt, err | ||||
| } | ||||
| @ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. | ||||
| 	if err == nil { | ||||
| 		result, err = stmt.ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			db.mux.Lock() | ||||
| 			db.Mux.Lock() | ||||
| 			stmt.Close() | ||||
| 			delete(db.Stmts, query) | ||||
| 			db.mux.Unlock() | ||||
| 			db.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return result, err | ||||
| @ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . | ||||
| 	if err == nil { | ||||
| 		rows, err = stmt.QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			db.mux.Lock() | ||||
| 			db.Mux.Lock() | ||||
| 			stmt.Close() | ||||
| 			delete(db.Stmts, query) | ||||
| 			db.mux.Unlock() | ||||
| 			db.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return rows, err | ||||
| @ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | ||||
| 	if err == nil { | ||||
| 		result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			tx.PreparedStmtDB.mux.Lock() | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			stmt.Close() | ||||
| 			delete(tx.PreparedStmtDB.Stmts, query) | ||||
| 			tx.PreparedStmtDB.mux.Unlock() | ||||
| 			tx.PreparedStmtDB.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return result, err | ||||
| @ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . | ||||
| 	if err == nil { | ||||
| 		rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			tx.PreparedStmtDB.mux.Lock() | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			stmt.Close() | ||||
| 			delete(tx.PreparedStmtDB.Stmts, query) | ||||
| 			tx.PreparedStmtDB.mux.Unlock() | ||||
| 			tx.PreparedStmtDB.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return rows, err | ||||
|  | ||||
| @ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) { | ||||
| 			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, | ||||
| 			results:   []string{"c1", "c4", "c3"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}}, | ||||
| 			results:   []string{"c5", "c1", "c2", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, | ||||
| 			results:   []string{"c5", "c1", "c2", "c4", "c3"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for idx, data := range datas { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu