Use *gorm.DB to replace gorm.DB
This commit is contained in:
		
							parent
							
								
									2a0c3e39f2
								
							
						
					
					
						commit
						9e8a4db36b
					
				| @ -90,7 +90,6 @@ func (p *processor) Execute(db *DB) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if stmt := db.Statement; stmt != nil { | 	if stmt := db.Statement; stmt != nil { | ||||||
| 		db.Error = stmt.Error |  | ||||||
| 		db.RowsAffected = stmt.RowsAffected | 		db.RowsAffected = stmt.RowsAffected | ||||||
| 
 | 
 | ||||||
| 		db.Logger.Trace(curTime, func() (string, int64) { | 		db.Logger.Trace(curTime, func() (string, int64) { | ||||||
|  | |||||||
| @ -13,14 +13,14 @@ import ( | |||||||
| //    db.Model(&User{}).Update("name", "hello")
 | //    db.Model(&User{}).Update("name", "hello")
 | ||||||
| //    // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
 | //    // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
 | ||||||
| //    db.Model(&user).Update("name", "hello")
 | //    db.Model(&user).Update("name", "hello")
 | ||||||
| func (db DB) Model(value interface{}) (tx DB) { | func (db *DB) Model(value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Model = value | 	tx.Statement.Model = value | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Clauses Add clauses
 | // Clauses Add clauses
 | ||||||
| func (db DB) Clauses(conds ...clause.Expression) (tx DB) { | func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	var whereConds []interface{} | 	var whereConds []interface{} | ||||||
| 
 | 
 | ||||||
| @ -39,14 +39,14 @@ func (db DB) Clauses(conds ...clause.Expression) (tx DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Table specify the table you would like to run db operations
 | // Table specify the table you would like to run db operations
 | ||||||
| func (db DB) Table(name string) (tx DB) { | func (db *DB) Table(name string) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Table = name | 	tx.Statement.Table = name | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Select specify fields that you want when querying, creating, updating
 | // Select specify fields that you want when querying, creating, updating
 | ||||||
| func (db DB) Select(query interface{}, args ...interface{}) (tx DB) { | func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 
 | 
 | ||||||
| 	switch v := query.(type) { | 	switch v := query.(type) { | ||||||
| @ -97,7 +97,7 @@ func (db DB) Select(query interface{}, args ...interface{}) (tx DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Omit specify fields that you want to ignore when creating, updating and querying
 | // Omit specify fields that you want to ignore when creating, updating and querying
 | ||||||
| func (db DB) Omit(columns ...string) (tx DB) { | func (db *DB) Omit(columns ...string) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 
 | 
 | ||||||
| 	if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { | 	if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { | ||||||
| @ -108,21 +108,21 @@ func (db DB) Omit(columns ...string) (tx DB) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Where(query interface{}, args ...interface{}) (tx DB) { | func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) | 	tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Not add NOT condition
 | // Not add NOT condition
 | ||||||
| func (db DB) Not(query interface{}, args ...interface{}) (tx DB) { | func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) | 	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Or add OR conditions
 | // Or add OR conditions
 | ||||||
| func (db DB) Or(query interface{}, args ...interface{}) (tx DB) { | func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) | 	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) | ||||||
| 	return | 	return | ||||||
| @ -131,13 +131,13 @@ func (db DB) Or(query interface{}, args ...interface{}) (tx DB) { | |||||||
| // Joins specify Joins conditions
 | // Joins specify Joins conditions
 | ||||||
| //     db.Joins("Account").Find(&user)
 | //     db.Joins("Account").Find(&user)
 | ||||||
| //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | ||||||
| func (db DB) Joins(query string, args ...interface{}) (tx DB) { | func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Group specify the group method on the find
 | // Group specify the group method on the find
 | ||||||
| func (db DB) Group(name string) (tx DB) { | func (db *DB) Group(name string) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.AddClause(clause.GroupBy{ | 	tx.Statement.AddClause(clause.GroupBy{ | ||||||
| 		Columns: []clause.Column{{Name: name}}, | 		Columns: []clause.Column{{Name: name}}, | ||||||
| @ -146,7 +146,7 @@ func (db DB) Group(name string) (tx DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Having specify HAVING conditions for GROUP BY
 | // Having specify HAVING conditions for GROUP BY
 | ||||||
| func (db DB) Having(query interface{}, args ...interface{}) (tx DB) { | func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.AddClause(clause.GroupBy{ | 	tx.Statement.AddClause(clause.GroupBy{ | ||||||
| 		Having: tx.Statement.BuildCondtion(query, args...), | 		Having: tx.Statement.BuildCondtion(query, args...), | ||||||
| @ -157,7 +157,7 @@ func (db DB) Having(query interface{}, args ...interface{}) (tx DB) { | |||||||
| // Order specify order when retrieve records from database
 | // Order specify order when retrieve records from database
 | ||||||
| //     db.Order("name DESC")
 | //     db.Order("name DESC")
 | ||||||
| //     db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
 | //     db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
 | ||||||
| func (db DB) Order(value interface{}) (tx DB) { | func (db *DB) Order(value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 
 | 
 | ||||||
| 	switch v := value.(type) { | 	switch v := value.(type) { | ||||||
| @ -176,14 +176,14 @@ func (db DB) Order(value interface{}) (tx DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Limit specify the number of records to be retrieved
 | // Limit specify the number of records to be retrieved
 | ||||||
| func (db DB) Limit(limit int) (tx DB) { | func (db *DB) Limit(limit int) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.AddClause(clause.Limit{Limit: limit}) | 	tx.Statement.AddClause(clause.Limit{Limit: limit}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Offset specify the number of records to skip before starting to return the records
 | // Offset specify the number of records to skip before starting to return the records
 | ||||||
| func (db DB) Offset(offset int) (tx DB) { | func (db *DB) Offset(offset int) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.AddClause(clause.Limit{Offset: offset}) | 	tx.Statement.AddClause(clause.Limit{Offset: offset}) | ||||||
| 	return | 	return | ||||||
| @ -201,7 +201,7 @@ func (db DB) Offset(offset int) (tx DB) { | |||||||
| //     }
 | //     }
 | ||||||
| //
 | //
 | ||||||
| //     db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
 | //     db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
 | ||||||
| func (db DB) Scopes(funcs ...func(DB) DB) DB { | func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { | ||||||
| 	for _, f := range funcs { | 	for _, f := range funcs { | ||||||
| 		db = f(db) | 		db = f(db) | ||||||
| 	} | 	} | ||||||
| @ -210,27 +210,27 @@ func (db DB) Scopes(funcs ...func(DB) DB) DB { | |||||||
| 
 | 
 | ||||||
| // Preload preload associations with given conditions
 | // Preload preload associations with given conditions
 | ||||||
| //    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | //    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | ||||||
| func (db DB) Preload(column string, conditions ...interface{}) (tx DB) { | func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Assign(attrs ...interface{}) (tx DB) { | func (db *DB) Assign(attrs ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Attrs(attrs ...interface{}) (tx DB) { | func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Unscoped() (tx DB) { | func (db *DB) Unscoped() (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Raw(sql string, values ...interface{}) (tx DB) { | func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.SQL = strings.Builder{} | 	tx.Statement.SQL = strings.Builder{} | ||||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	DB  gorm.DB | 	DB  *gorm.DB | ||||||
| 	err error | 	err error | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -23,9 +23,9 @@ func init() { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestCURD(t *testing.T) { | func TestCURD(t *testing.T) { | ||||||
| 	tests.RunTestsSuit(t, &DB) | 	tests.RunTestsSuit(t, DB) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestMigrate(t *testing.T) { | func TestMigrate(t *testing.T) { | ||||||
| 	tests.TestMigrate(t, &DB) | 	tests.TestMigrate(t, DB) | ||||||
| } | } | ||||||
|  | |||||||
| @ -9,15 +9,15 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Create insert the value into database
 | // Create insert the value into database
 | ||||||
| func (db DB) Create(value interface{}) (tx DB) { | func (db *DB) Create(value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = value | 	tx.Statement.Dest = value | ||||||
| 	tx.callbacks.Create().Execute(&tx) | 	tx.callbacks.Create().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Save update value in database, if the value doesn't have primary key, will insert it
 | // Save update value in database, if the value doesn't have primary key, will insert it
 | ||||||
| func (db DB) Save(value interface{}) (tx DB) { | func (db *DB) Save(value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = value | 	tx.Statement.Dest = value | ||||||
| 
 | 
 | ||||||
| @ -26,7 +26,7 @@ func (db DB) Save(value interface{}) (tx DB) { | |||||||
| 		reflectValue := reflect.ValueOf(value) | 		reflectValue := reflect.ValueOf(value) | ||||||
| 		for idx, pf := range tx.Statement.Schema.PrimaryFields { | 		for idx, pf := range tx.Statement.Schema.PrimaryFields { | ||||||
| 			if pv, isZero := pf.ValueOf(reflectValue); isZero { | 			if pv, isZero := pf.ValueOf(reflectValue); isZero { | ||||||
| 				tx.callbacks.Create().Execute(&tx) | 				tx.callbacks.Create().Execute(tx) | ||||||
| 				where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} | 				where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @ -38,12 +38,12 @@ func (db DB) Save(value interface{}) (tx DB) { | |||||||
| 	if len(tx.Statement.Selects) == 0 { | 	if len(tx.Statement.Selects) == 0 { | ||||||
| 		tx.Statement.Selects = []string{"*"} | 		tx.Statement.Selects = []string{"*"} | ||||||
| 	} | 	} | ||||||
| 	tx.callbacks.Update().Execute(&tx) | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // First find first record that match given conditions, order by primary key
 | // First find first record that match given conditions, order by primary key
 | ||||||
| func (db DB) First(out interface{}, conds ...interface{}) (tx DB) { | func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||||
| 	}) | 	}) | ||||||
| @ -52,24 +52,24 @@ func (db DB) First(out interface{}, conds ...interface{}) (tx DB) { | |||||||
| 	} | 	} | ||||||
| 	tx.Statement.RaiseErrorOnNotFound = true | 	tx.Statement.RaiseErrorOnNotFound = true | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(&tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Take return a record that match given conditions, the order will depend on the database implementation
 | // Take return a record that match given conditions, the order will depend on the database implementation
 | ||||||
| func (db DB) Take(out interface{}, conds ...interface{}) (tx DB) { | func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance().Limit(1) | 	tx = db.getInstance().Limit(1) | ||||||
| 	if len(conds) > 0 { | 	if len(conds) > 0 { | ||||||
| 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||||
| 	} | 	} | ||||||
| 	tx.Statement.RaiseErrorOnNotFound = true | 	tx.Statement.RaiseErrorOnNotFound = true | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(&tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Last find last record that match given conditions, order by primary key
 | // Last find last record that match given conditions, order by primary key
 | ||||||
| func (db DB) Last(out interface{}, conds ...interface{}) (tx DB) { | func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||||
| 		Desc:   true, | 		Desc:   true, | ||||||
| @ -79,101 +79,101 @@ func (db DB) Last(out interface{}, conds ...interface{}) (tx DB) { | |||||||
| 	} | 	} | ||||||
| 	tx.Statement.RaiseErrorOnNotFound = true | 	tx.Statement.RaiseErrorOnNotFound = true | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(&tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Find find records that match given conditions
 | // Find find records that match given conditions
 | ||||||
| func (db DB) Find(out interface{}, conds ...interface{}) (tx DB) { | func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	if len(conds) > 0 { | 	if len(conds) > 0 { | ||||||
| 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||||
| 	} | 	} | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(&tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) FirstOrInit(out interface{}, where ...interface{}) (tx DB) { | func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) FirstOrCreate(out interface{}, where ...interface{}) (tx DB) { | func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||||
| func (db DB) Update(column string, value interface{}) (tx DB) { | func (db *DB) Update(column string, value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = map[string]interface{}{column: value} | 	tx.Statement.Dest = map[string]interface{}{column: value} | ||||||
| 	tx.callbacks.Update().Execute(&tx) | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||||
| func (db DB) Updates(values interface{}) (tx DB) { | func (db *DB) Updates(values interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = values | 	tx.Statement.Dest = values | ||||||
| 	tx.callbacks.Update().Execute(&tx) | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) UpdateColumn(column string, value interface{}) (tx DB) { | func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = map[string]interface{}{column: value} | 	tx.Statement.Dest = map[string]interface{}{column: value} | ||||||
| 	tx.callbacks.Update().Execute(&tx) | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) UpdateColumns(values interface{}) (tx DB) { | func (db *DB) UpdateColumns(values interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = values | 	tx.Statement.Dest = values | ||||||
| 	tx.callbacks.Update().Execute(&tx) | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
 | // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
 | ||||||
| func (db DB) Delete(value interface{}, conds ...interface{}) (tx DB) { | func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	if len(conds) > 0 { | 	if len(conds) > 0 { | ||||||
| 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||||
| 	} | 	} | ||||||
| 	tx.Statement.Dest = value | 	tx.Statement.Dest = value | ||||||
| 	tx.callbacks.Delete().Execute(&tx) | 	tx.callbacks.Delete().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Count(value interface{}) (tx DB) { | func (db *DB) Count(value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Row() *sql.Row { | func (db *DB) Row() *sql.Row { | ||||||
| 	tx := db.getInstance() | 	tx := db.getInstance() | ||||||
| 	tx.callbacks.Row().Execute(&tx) | 	tx.callbacks.Row().Execute(tx) | ||||||
| 	return tx.Statement.Dest.(*sql.Row) | 	return tx.Statement.Dest.(*sql.Row) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) Rows() (*sql.Rows, error) { | func (db *DB) Rows() (*sql.Rows, error) { | ||||||
| 	tx := db.Set("rows", true) | 	tx := db.Set("rows", true) | ||||||
| 	tx.callbacks.Row().Execute(&tx) | 	tx.callbacks.Row().Execute(tx) | ||||||
| 	return tx.Statement.Dest.(*sql.Rows), tx.Error | 	return tx.Statement.Dest.(*sql.Rows), tx.Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Scan scan value to a struct
 | // Scan scan value to a struct
 | ||||||
| func (db DB) Scan(dest interface{}) (tx DB) { | func (db *DB) Scan(dest interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) ScanRows(rows *sql.Rows, result interface{}) error { | func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 | // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 | ||||||
| func (db DB) Transaction(fc func(tx DB) error, opts ...*sql.TxOptions) (err error) { | func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { | ||||||
| 	panicked := true | 	panicked := true | ||||||
| 	tx := db.Begin(opts...) | 	tx := db.Begin(opts...) | ||||||
| 	defer func() { | 	defer func() { | ||||||
| @ -194,7 +194,7 @@ func (db DB) Transaction(fc func(tx DB) error, opts ...*sql.TxOptions) (err erro | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Begin begins a transaction
 | // Begin begins a transaction
 | ||||||
| func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) { | func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { | 	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { | ||||||
| 		var opt *sql.TxOptions | 		var opt *sql.TxOptions | ||||||
| @ -213,7 +213,7 @@ func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Commit commit a transaction
 | // Commit commit a transaction
 | ||||||
| func (db DB) Commit() DB { | func (db *DB) Commit() *DB { | ||||||
| 	if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { | 	if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { | ||||||
| 		db.AddError(comminter.Commit()) | 		db.AddError(comminter.Commit()) | ||||||
| 	} else { | 	} else { | ||||||
| @ -223,7 +223,7 @@ func (db DB) Commit() DB { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Rollback rollback a transaction
 | // Rollback rollback a transaction
 | ||||||
| func (db DB) Rollback() DB { | func (db *DB) Rollback() *DB { | ||||||
| 	if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { | 	if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { | ||||||
| 		db.AddError(comminter.Rollback()) | 		db.AddError(comminter.Rollback()) | ||||||
| 	} else { | 	} else { | ||||||
| @ -233,10 +233,10 @@ func (db DB) Rollback() DB { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Exec execute raw sql
 | // Exec execute raw sql
 | ||||||
| func (db DB) Exec(sql string, values ...interface{}) (tx DB) { | func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.SQL = strings.Builder{} | 	tx.Statement.SQL = strings.Builder{} | ||||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||||
| 	tx.callbacks.Raw().Execute(&tx) | 	tx.callbacks.Raw().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										35
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								gorm.go
									
									
									
									
									
								
							| @ -2,6 +2,7 @@ package gorm | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| @ -51,7 +52,7 @@ type Session struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Open initialize db session based on dialector
 | // Open initialize db session based on dialector
 | ||||||
| func Open(dialector Dialector, config *Config) (db DB, err error) { | func Open(dialector Dialector, config *Config) (db *DB, err error) { | ||||||
| 	if config == nil { | 	if config == nil { | ||||||
| 		config = &Config{} | 		config = &Config{} | ||||||
| 	} | 	} | ||||||
| @ -87,21 +88,21 @@ func Open(dialector Dialector, config *Config) (db DB, err error) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	db = DB{ | 	db = &DB{ | ||||||
| 		Config: config, | 		Config: config, | ||||||
| 		clone:  true, | 		clone:  true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	db.callbacks = initializeCallbacks(&db) | 	db.callbacks = initializeCallbacks(db) | ||||||
| 
 | 
 | ||||||
| 	if dialector != nil { | 	if dialector != nil { | ||||||
| 		err = dialector.Initialize(&db) | 		err = dialector.Initialize(db) | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Session create new db session
 | // Session create new db session
 | ||||||
| func (db DB) Session(config *Session) DB { | func (db *DB) Session(config *Session) *DB { | ||||||
| 	var ( | 	var ( | ||||||
| 		tx       = db.getInstance() | 		tx       = db.getInstance() | ||||||
| 		txConfig = *tx.Config | 		txConfig = *tx.Config | ||||||
| @ -125,24 +126,24 @@ func (db DB) Session(config *Session) DB { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // WithContext change current instance db's context to ctx
 | // WithContext change current instance db's context to ctx
 | ||||||
| func (db DB) WithContext(ctx context.Context) DB { | func (db *DB) WithContext(ctx context.Context) *DB { | ||||||
| 	return db.Session(&Session{Context: ctx}) | 	return db.Session(&Session{Context: ctx}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Debug start debug mode
 | // Debug start debug mode
 | ||||||
| func (db DB) Debug() (tx DB) { | func (db *DB) Debug() (tx *DB) { | ||||||
| 	return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) | 	return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Set store value with key into current db instance's context
 | // Set store value with key into current db instance's context
 | ||||||
| func (db DB) Set(key string, value interface{}) DB { | func (db *DB) Set(key string, value interface{}) *DB { | ||||||
| 	tx := db.getInstance() | 	tx := db.getInstance() | ||||||
| 	tx.Statement.Settings.Store(key, value) | 	tx.Statement.Settings.Store(key, value) | ||||||
| 	return tx | 	return tx | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Get get value with key from current db instance's context
 | // Get get value with key from current db instance's context
 | ||||||
| func (db DB) Get(key string) (interface{}, bool) { | func (db *DB) Get(key string) (interface{}, bool) { | ||||||
| 	if db.Statement != nil { | 	if db.Statement != nil { | ||||||
| 		return db.Statement.Settings.Load(key) | 		return db.Statement.Settings.Load(key) | ||||||
| 	} | 	} | ||||||
| @ -150,28 +151,32 @@ func (db DB) Get(key string) (interface{}, bool) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Callback returns callback manager
 | // Callback returns callback manager
 | ||||||
| func (db DB) Callback() *callbacks { | func (db *DB) Callback() *callbacks { | ||||||
| 	return db.callbacks | 	return db.callbacks | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // AutoMigrate run auto migration for given models
 | // AutoMigrate run auto migration for given models
 | ||||||
| func (db DB) AutoMigrate(dst ...interface{}) error { | func (db *DB) AutoMigrate(dst ...interface{}) error { | ||||||
| 	return db.Migrator().AutoMigrate(dst...) | 	return db.Migrator().AutoMigrate(dst...) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // AddError add error to db
 | // AddError add error to db
 | ||||||
| func (db DB) AddError(err error) { | func (db *DB) AddError(err error) { | ||||||
| 	db.Statement.AddError(err) | 	if db.Error == nil { | ||||||
|  | 		db.Error = err | ||||||
|  | 	} else if err != nil { | ||||||
|  | 		db.Error = fmt.Errorf("%v; %w", db.Error, err) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db DB) getInstance() DB { | func (db *DB) getInstance() *DB { | ||||||
| 	if db.clone { | 	if db.clone { | ||||||
| 		stmt := db.Config.statementPool.Get().(*Statement) | 		stmt := db.Config.statementPool.Get().(*Statement) | ||||||
| 		if db.Statement != nil { | 		if db.Statement != nil { | ||||||
| 			stmt.Context = db.Statement.Context | 			stmt.Context = db.Statement.Context | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return DB{Config: db.Config, Statement: stmt} | 		return &DB{Config: db.Config, Statement: stmt} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return db | 	return db | ||||||
|  | |||||||
| @ -27,7 +27,7 @@ type Config struct { | |||||||
| func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { | func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { | ||||||
| 	stmt := m.DB.Statement | 	stmt := m.DB.Statement | ||||||
| 	if stmt == nil { | 	if stmt == nil { | ||||||
| 		stmt = &gorm.Statement{DB: *m.DB} | 		stmt = &gorm.Statement{DB: m.DB} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := stmt.Parse(value); err != nil { | 	if err := stmt.Parse(value); err != nil { | ||||||
| @ -496,7 +496,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i | |||||||
| 
 | 
 | ||||||
| 	parseDependence := func(value interface{}, addToList bool) { | 	parseDependence := func(value interface{}, addToList bool) { | ||||||
| 		dep := Dependency{ | 		dep := Dependency{ | ||||||
| 			Statement: &gorm.Statement{DB: *m.DB, Dest: value}, | 			Statement: &gorm.Statement{DB: m.DB, Dest: value}, | ||||||
| 		} | 		} | ||||||
| 		dep.Parse(value) | 		dep.Parse(value) | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										10
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								statement.go
									
									
									
									
									
								
							| @ -16,6 +16,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| // Statement statement
 | // Statement statement
 | ||||||
| type Statement struct { | type Statement struct { | ||||||
|  | 	*DB | ||||||
| 	Table                string | 	Table                string | ||||||
| 	Model                interface{} | 	Model                interface{} | ||||||
| 	Dest                 interface{} | 	Dest                 interface{} | ||||||
| @ -25,7 +26,6 @@ type Statement struct { | |||||||
| 	Omits                []string // omit columns
 | 	Omits                []string // omit columns
 | ||||||
| 	Settings             sync.Map | 	Settings             sync.Map | ||||||
| 	ConnPool             ConnPool | 	ConnPool             ConnPool | ||||||
| 	DB                   DB |  | ||||||
| 	Schema               *schema.Schema | 	Schema               *schema.Schema | ||||||
| 	Context              context.Context | 	Context              context.Context | ||||||
| 	Error                error | 	Error                error | ||||||
| @ -219,14 +219,6 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con | |||||||
| 	return conditions | 	return conditions | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (stmt *Statement) AddError(err error) { |  | ||||||
| 	if stmt.Error == nil { |  | ||||||
| 		stmt.Error = err |  | ||||||
| 	} else if err != nil { |  | ||||||
| 		stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Build build sql with clauses names
 | // Build build sql with clauses names
 | ||||||
| func (stmt *Statement) Build(clauses ...string) { | func (stmt *Statement) Build(clauses ...string) { | ||||||
| 	var firstClauseWritten bool | 	var firstClauseWritten bool | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu