Add GormValuer interface support
This commit is contained in:
		
							parent
							
								
									0d96f99499
								
							
						
					
					
						commit
						ce8853e7a6
					
				| @ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. | |||||||
| * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point | * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point | ||||||
| * Context, Prepared Statment Mode, DryRun Mode | * Context, Prepared Statment Mode, DryRun Mode | ||||||
| * Batch Insert, FindInBatches, Find To Map | * Batch Insert, FindInBatches, Find To Map | ||||||
| * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg | * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr | ||||||
| * Composite Primary Key | * Composite Primary Key | ||||||
| * Auto Migrations | * Auto Migrations | ||||||
| * Logger | * Logger | ||||||
|  | |||||||
| @ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) { | |||||||
| 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { | 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | ||||||
| 			if db.Statement.Schema.BeforeSave { | 			if db.Statement.Schema.BeforeSave { | ||||||
| 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | 				if i, ok := value.(BeforeSaveInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.BeforeSave(tx)) | 					db.AddError(i.BeforeSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.BeforeCreate { | 			if db.Statement.Schema.BeforeCreate { | ||||||
| 				if i, ok := value.(gorm.BeforeCreateInterface); ok { | 				if i, ok := value.(BeforeCreateInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.BeforeCreate(tx)) | 					db.AddError(i.BeforeCreate(tx)) | ||||||
| 				} | 				} | ||||||
| @ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) { | |||||||
| 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { | 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | ||||||
| 			if db.Statement.Schema.AfterSave { | 			if db.Statement.Schema.AfterSave { | ||||||
| 				if i, ok := value.(gorm.AfterSaveInterface); ok { | 				if i, ok := value.(AfterSaveInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.AfterSave(tx)) | 					db.AddError(i.AfterSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.AfterCreate { | 			if db.Statement.Schema.AfterCreate { | ||||||
| 				if i, ok := value.(gorm.AfterCreateInterface); ok { | 				if i, ok := value.(AfterCreateInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.AfterCreate(tx)) | 					db.AddError(i.AfterCreate(tx)) | ||||||
| 				} | 				} | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ import ( | |||||||
| func BeforeDelete(db *gorm.DB) { | func BeforeDelete(db *gorm.DB) { | ||||||
| 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { | 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | ||||||
| 			if i, ok := value.(gorm.BeforeDeleteInterface); ok { | 			if i, ok := value.(BeforeDeleteInterface); ok { | ||||||
| 				db.AddError(i.BeforeDelete(tx)) | 				db.AddError(i.BeforeDelete(tx)) | ||||||
| 				return true | 				return true | ||||||
| 			} | 			} | ||||||
| @ -75,7 +75,7 @@ func Delete(db *gorm.DB) { | |||||||
| func AfterDelete(db *gorm.DB) { | func AfterDelete(db *gorm.DB) { | ||||||
| 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { | 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | ||||||
| 			if i, ok := value.(gorm.AfterDeleteInterface); ok { | 			if i, ok := value.(AfterDeleteInterface); ok { | ||||||
| 				db.AddError(i.AfterDelete(tx)) | 				db.AddError(i.AfterDelete(tx)) | ||||||
| 				return true | 				return true | ||||||
| 			} | 			} | ||||||
|  | |||||||
							
								
								
									
										39
									
								
								callbacks/interfaces.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								callbacks/interfaces.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,39 @@ | |||||||
|  | package callbacks | ||||||
|  | 
 | ||||||
|  | import "gorm.io/gorm" | ||||||
|  | 
 | ||||||
|  | type BeforeCreateInterface interface { | ||||||
|  | 	BeforeCreate(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterCreateInterface interface { | ||||||
|  | 	AfterCreate(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type BeforeUpdateInterface interface { | ||||||
|  | 	BeforeUpdate(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterUpdateInterface interface { | ||||||
|  | 	AfterUpdate(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type BeforeSaveInterface interface { | ||||||
|  | 	BeforeSave(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterSaveInterface interface { | ||||||
|  | 	AfterSave(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type BeforeDeleteInterface interface { | ||||||
|  | 	BeforeDelete(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterDeleteInterface interface { | ||||||
|  | 	AfterDelete(*gorm.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterFindInterface interface { | ||||||
|  | 	AfterFind(*gorm.DB) error | ||||||
|  | } | ||||||
| @ -214,7 +214,7 @@ func Preload(db *gorm.DB) { | |||||||
| func AfterQuery(db *gorm.DB) { | func AfterQuery(db *gorm.DB) { | ||||||
| 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { | 	if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | ||||||
| 			if i, ok := value.(gorm.AfterFindInterface); ok { | 			if i, ok := value.(AfterFindInterface); ok { | ||||||
| 				db.AddError(i.AfterFind(tx)) | 				db.AddError(i.AfterFind(tx)) | ||||||
| 				return true | 				return true | ||||||
| 			} | 			} | ||||||
|  | |||||||
| @ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) { | |||||||
| 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | ||||||
| 			if db.Statement.Schema.BeforeSave { | 			if db.Statement.Schema.BeforeSave { | ||||||
| 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | 				if i, ok := value.(BeforeSaveInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.BeforeSave(tx)) | 					db.AddError(i.BeforeSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.BeforeUpdate { | 			if db.Statement.Schema.BeforeUpdate { | ||||||
| 				if i, ok := value.(gorm.BeforeUpdateInterface); ok { | 				if i, ok := value.(BeforeUpdateInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.BeforeUpdate(tx)) | 					db.AddError(i.BeforeUpdate(tx)) | ||||||
| 				} | 				} | ||||||
| @ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) { | |||||||
| 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { | 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | ||||||
| 			if db.Statement.Schema.AfterSave { | 			if db.Statement.Schema.AfterSave { | ||||||
| 				if i, ok := value.(gorm.AfterSaveInterface); ok { | 				if i, ok := value.(AfterSaveInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.AfterSave(tx)) | 					db.AddError(i.AfterSave(tx)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema.AfterUpdate { | 			if db.Statement.Schema.AfterUpdate { | ||||||
| 				if i, ok := value.(gorm.AfterUpdateInterface); ok { | 				if i, ok := value.(AfterUpdateInterface); ok { | ||||||
| 					called = true | 					called = true | ||||||
| 					db.AddError(i.AfterUpdate(tx)) | 					db.AddError(i.AfterUpdate(tx)) | ||||||
| 				} | 				} | ||||||
|  | |||||||
| @ -53,38 +53,7 @@ type TxCommitter interface { | |||||||
| 	Rollback() error | 	Rollback() error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type BeforeCreateInterface interface { | // Valuer gorm valuer interface
 | ||||||
| 	BeforeCreate(*DB) error | type Valuer interface { | ||||||
| } | 	GormValue(context.Context, *DB) clause.Expr | ||||||
| 
 |  | ||||||
| type AfterCreateInterface interface { |  | ||||||
| 	AfterCreate(*DB) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BeforeUpdateInterface interface { |  | ||||||
| 	BeforeUpdate(*DB) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AfterUpdateInterface interface { |  | ||||||
| 	AfterUpdate(*DB) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BeforeSaveInterface interface { |  | ||||||
| 	BeforeSave(*DB) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AfterSaveInterface interface { |  | ||||||
| 	AfterSave(*DB) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BeforeDeleteInterface interface { |  | ||||||
| 	BeforeDelete(*DB) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AfterDeleteInterface interface { |  | ||||||
| 	AfterDelete(*DB) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AfterFindInterface interface { |  | ||||||
| 	AfterFind(*DB) error |  | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| package schema | package schema | ||||||
| 
 | 
 | ||||||
| import "gorm.io/gorm/clause" | import ( | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| type GormDataTypeInterface interface { | type GormDataTypeInterface interface { | ||||||
| 	GormDataType() string | 	GormDataType() string | ||||||
|  | |||||||
| @ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | |||||||
| 			stmt.Vars = append(stmt.Vars, v.Value) | 			stmt.Vars = append(stmt.Vars, v.Value) | ||||||
| 		case clause.Column, clause.Table: | 		case clause.Column, clause.Table: | ||||||
| 			stmt.QuoteTo(writer, v) | 			stmt.QuoteTo(writer, v) | ||||||
|  | 		case Valuer: | ||||||
|  | 			stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) | ||||||
| 		case clause.Expr: | 		case clause.Expr: | ||||||
| 			var varStr strings.Builder | 			var varStr strings.Builder | ||||||
| 			var sql = v.SQL | 			var sql = v.SQL | ||||||
|  | |||||||
| @ -1,16 +1,20 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"regexp" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) { | |||||||
| type NullString struct { | type NullString struct { | ||||||
| 	sql.NullString | 	sql.NullString | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type Point struct { | ||||||
|  | 	X, Y int | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (point *Point) Scan(v interface{}) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (point Point) GormDataType() string { | ||||||
|  | 	return "geo" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { | ||||||
|  | 	return clause.Expr{ | ||||||
|  | 		SQL:  "ST_PointFromText(?)", | ||||||
|  | 		Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGORMValuer(t *testing.T) { | ||||||
|  | 	type UserWithPoint struct { | ||||||
|  | 		Name  string | ||||||
|  | 		Point Point | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	dryRunDB := DB.Session(&gorm.Session{DryRun: true}) | ||||||
|  | 
 | ||||||
|  | 	stmt := dryRunDB.Create(&UserWithPoint{ | ||||||
|  | 		Name:  "jinzhu", | ||||||
|  | 		Point: Point{X: 100, Y: 100}, | ||||||
|  | 	}).Statement | ||||||
|  | 
 | ||||||
|  | 	if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { | ||||||
|  | 		t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { | ||||||
|  | 		t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { | ||||||
|  | 		t.Errorf("generated vars is not equal, got %v", stmt.Vars) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu