Add Before/After callbacks
This commit is contained in:
		
							parent
							
								
									fa22807e12
								
							
						
					
					
						commit
						e2a360b9fa
					
				| @ -8,8 +8,36 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BeforeCreate(db *gorm.DB) { | func BeforeCreate(db *gorm.DB) { | ||||||
| 	// before save
 | 	if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { | ||||||
| 	// before create
 | 		callMethod := func(value interface{}) bool { | ||||||
|  | 			var ok bool | ||||||
|  | 			if db.Statement.Schema.BeforeSave { | ||||||
|  | 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.BeforeSave(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if db.Statement.Schema.BeforeCreate { | ||||||
|  | 				if i, ok := value.(gorm.BeforeCreateInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.BeforeCreate(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return ok | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				callMethod(db.Statement.ReflectValue.Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func SaveBeforeAssociations(db *gorm.DB) { | func SaveBeforeAssociations(db *gorm.DB) { | ||||||
| @ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterCreate(db *gorm.DB) { | func AfterCreate(db *gorm.DB) { | ||||||
| 	// after save
 | 	if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { | ||||||
| 	// after create
 | 		callMethod := func(value interface{}) bool { | ||||||
|  | 			var ok bool | ||||||
|  | 			if db.Statement.Schema.AfterSave { | ||||||
|  | 				if i, ok := value.(gorm.AfterSaveInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.AfterSave(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if db.Statement.Schema.AfterCreate { | ||||||
|  | 				if i, ok := value.(gorm.AfterCreateInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.AfterCreate(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return ok | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				callMethod(db.Statement.ReflectValue.Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ConvertToCreateValues convert to create values
 | // ConvertToCreateValues convert to create values
 | ||||||
|  | |||||||
| @ -1,12 +1,60 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import "github.com/jinzhu/gorm" | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| func BeforeDelete(db *gorm.DB) { | func BeforeDelete(db *gorm.DB) { | ||||||
|  | 	if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { | ||||||
|  | 		callMethod := func(value interface{}) bool { | ||||||
|  | 			if db.Statement.Schema.BeforeDelete { | ||||||
|  | 				if i, ok := value.(gorm.BeforeDeleteInterface); ok { | ||||||
|  | 					i.BeforeDelete(db) | ||||||
|  | 					return true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				callMethod(db.Statement.ReflectValue.Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Delete(db *gorm.DB) { | func Delete(db *gorm.DB) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterDelete(db *gorm.DB) { | func AfterDelete(db *gorm.DB) { | ||||||
|  | 	if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { | ||||||
|  | 		callMethod := func(value interface{}) bool { | ||||||
|  | 			if db.Statement.Schema.AfterDelete { | ||||||
|  | 				if i, ok := value.(gorm.AfterDeleteInterface); ok { | ||||||
|  | 					i.AfterDelete(db) | ||||||
|  | 					return true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				callMethod(db.Statement.ReflectValue.Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/clause" | 	"github.com/jinzhu/gorm/clause" | ||||||
| ) | ) | ||||||
| @ -13,7 +15,7 @@ func Query(db *gorm.DB) { | |||||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 	_, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 	db.AddError(err) | 	db.AddError(err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -21,5 +23,26 @@ func Preload(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterQuery(db *gorm.DB) { | func AfterQuery(db *gorm.DB) { | ||||||
| 	// after find
 | 	if db.Statement.Schema != nil && db.Statement.Schema.AfterFind { | ||||||
|  | 		callMethod := func(value interface{}) bool { | ||||||
|  | 			if db.Statement.Schema.AfterFind { | ||||||
|  | 				if i, ok := value.(gorm.AfterFindInterface); ok { | ||||||
|  | 					i.AfterFind(db) | ||||||
|  | 					return true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				callMethod(db.Statement.ReflectValue.Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,12 +1,76 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import "github.com/jinzhu/gorm" | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| func BeforeUpdate(db *gorm.DB) { | func BeforeUpdate(db *gorm.DB) { | ||||||
|  | 	if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | ||||||
|  | 		callMethod := func(value interface{}) bool { | ||||||
|  | 			var ok bool | ||||||
|  | 			if db.Statement.Schema.BeforeSave { | ||||||
|  | 				if i, ok := value.(gorm.BeforeSaveInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.BeforeSave(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if db.Statement.Schema.BeforeUpdate { | ||||||
|  | 				if i, ok := value.(gorm.BeforeUpdateInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.BeforeUpdate(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return ok | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				callMethod(db.Statement.ReflectValue.Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Update(db *gorm.DB) { | func Update(db *gorm.DB) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterUpdate(db *gorm.DB) { | func AfterUpdate(db *gorm.DB) { | ||||||
|  | 	if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { | ||||||
|  | 		callMethod := func(value interface{}) bool { | ||||||
|  | 			var ok bool | ||||||
|  | 			if db.Statement.Schema.AfterSave { | ||||||
|  | 				if i, ok := value.(gorm.AfterSaveInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.AfterSave(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if db.Statement.Schema.AfterUpdate { | ||||||
|  | 				if i, ok := value.(gorm.AfterUpdateInterface); ok { | ||||||
|  | 					ok = true | ||||||
|  | 					i.AfterUpdate(db) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return ok | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if ok := callMethod(db.Statement.Dest); !ok { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 					callMethod(db.Statement.ReflectValue.Index(i).Interface()) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				callMethod(db.Statement.ReflectValue.Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BenchmarkSelect(b *testing.B) { | func BenchmarkSelect(b *testing.B) { | ||||||
| 	user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | 	user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | ||||||
| 
 | 
 | ||||||
| 	for i := 0; i < b.N; i++ { | 	for i := 0; i < b.N; i++ { | ||||||
| 		stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | 		stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | ||||||
| @ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func BenchmarkComplexSelect(b *testing.B) { | func BenchmarkComplexSelect(b *testing.B) { | ||||||
| 	user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | 	user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | ||||||
| 
 | 
 | ||||||
| 	for i := 0; i < b.N; i++ { | 	for i := 0; i < b.N; i++ { | ||||||
| 		stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | 		stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, | |||||||
| 	var ( | 	var ( | ||||||
| 		buildNames    []string | 		buildNames    []string | ||||||
| 		buildNamesMap = map[string]bool{} | 		buildNamesMap = map[string]bool{} | ||||||
| 		user, _       = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | 		user, _, _    = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | ||||||
| 		stmt          = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | 		stmt          = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -24,7 +24,7 @@ func TestExpr(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	for idx, result := range results { | 	for idx, result := range results { | ||||||
| 		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { | 		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { | ||||||
| 			user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | 			user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | ||||||
| 			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | 			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | ||||||
| 			clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) | 			clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) | ||||||
| 			if stmt.SQL.String() != result.Result { | 			if stmt.SQL.String() != result.Result { | ||||||
|  | |||||||
| @ -24,3 +24,39 @@ type CommonDB interface { | |||||||
| 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||||
| 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type BeforeCreateInterface interface { | ||||||
|  | 	BeforeCreate(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterCreateInterface interface { | ||||||
|  | 	AfterCreate(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type BeforeUpdateInterface interface { | ||||||
|  | 	BeforeUpdate(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterUpdateInterface interface { | ||||||
|  | 	AfterUpdate(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type BeforeSaveInterface interface { | ||||||
|  | 	BeforeSave(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterSaveInterface interface { | ||||||
|  | 	AfterSave(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type BeforeDeleteInterface interface { | ||||||
|  | 	BeforeDelete(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterDeleteInterface interface { | ||||||
|  | 	AfterDelete(*DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AfterFindInterface interface { | ||||||
|  | 	AfterFind(*DB) | ||||||
|  | } | ||||||
|  | |||||||
							
								
								
									
										38
									
								
								schema/callbacks_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								schema/callbacks_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,38 @@ | |||||||
|  | package schema_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type UserWithCallback struct { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (UserWithCallback) BeforeSave(*gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (UserWithCallback) AfterCreate(*gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCallback(t *testing.T) { | ||||||
|  | 	user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to parse user with callback, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, str := range []string{"BeforeSave", "AfterCreate"} { | ||||||
|  | 		if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { | ||||||
|  | 			t.Errorf("%v should be true", str) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { | ||||||
|  | 		if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { | ||||||
|  | 			t.Errorf("%v should be false", str) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -15,7 +15,7 @@ type UserCheck struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestParseCheck(t *testing.T) { | func TestParseCheck(t *testing.T) { | ||||||
| 	user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) | 	user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to parse user check, got error %v", err) | 		t.Fatalf("failed to parse user check, got error %v", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -14,8 +14,8 @@ import ( | |||||||
| 
 | 
 | ||||||
| func TestFieldValuerAndSetter(t *testing.T) { | func TestFieldValuerAndSetter(t *testing.T) { | ||||||
| 	var ( | 	var ( | ||||||
| 		userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) | 		userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 		user          = tests.User{ | 		user             = tests.User{ | ||||||
| 			Model: gorm.Model{ | 			Model: gorm.Model{ | ||||||
| 				ID:        10, | 				ID:        10, | ||||||
| 				CreatedAt: time.Now(), | 				CreatedAt: time.Now(), | ||||||
| @ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| func TestPointerFieldValuerAndSetter(t *testing.T) { | func TestPointerFieldValuerAndSetter(t *testing.T) { | ||||||
| 	var ( | 	var ( | ||||||
| 		userSchema, _      = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) | 		userSchema, _, _      = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 		name               = "pointer_field_valuer_and_setter" | 		name                  = "pointer_field_valuer_and_setter" | ||||||
| 		age           uint = 18 | 		age              uint = 18 | ||||||
| 		active             = true | 		active                = true | ||||||
| 		user               = User{ | 		user                  = User{ | ||||||
| 			Model: &gorm.Model{ | 			Model: &gorm.Model{ | ||||||
| 				ID:        10, | 				ID:        10, | ||||||
| 				CreatedAt: time.Now(), | 				CreatedAt: time.Now(), | ||||||
| @ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { | func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { | ||||||
| 	var ( | 	var ( | ||||||
| 		userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) | 		userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 		name          = "advanced_data_type_valuer_and_setter" | 		name             = "advanced_data_type_valuer_and_setter" | ||||||
| 		deletedAt     = mytime(time.Now()) | 		deletedAt        = mytime(time.Now()) | ||||||
| 		isAdmin       = mybool(false) | 		isAdmin          = mybool(false) | ||||||
| 		user          = AdvancedDataTypeUser{ | 		user             = AdvancedDataTypeUser{ | ||||||
| 			ID:           sql.NullInt64{Int64: 10, Valid: true}, | 			ID:           sql.NullInt64{Int64: 10, Valid: true}, | ||||||
| 			Name:         &sql.NullString{String: name, Valid: true}, | 			Name:         &sql.NullString{String: name, Valid: true}, | ||||||
| 			Birthday:     sql.NullTime{Time: time.Now(), Valid: true}, | 			Birthday:     sql.NullTime{Time: time.Now(), Valid: true}, | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ type UserIndex struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestParseIndex(t *testing.T) { | func TestParseIndex(t *testing.T) { | ||||||
| 	user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) | 	user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to parse user index, got error %v", err) | 		t.Fatalf("failed to parse user index, got error %v", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -14,20 +14,25 @@ import ( | |||||||
| var ErrUnsupportedDataType = errors.New("unsupported data type") | var ErrUnsupportedDataType = errors.New("unsupported data type") | ||||||
| 
 | 
 | ||||||
| type Schema struct { | type Schema struct { | ||||||
| 	Name                     string | 	Name                      string | ||||||
| 	ModelType                reflect.Type | 	ModelType                 reflect.Type | ||||||
| 	Table                    string | 	Table                     string | ||||||
| 	PrioritizedPrimaryField  *Field | 	PrioritizedPrimaryField   *Field | ||||||
| 	DBNames                  []string | 	DBNames                   []string | ||||||
| 	PrimaryFields            []*Field | 	PrimaryFields             []*Field | ||||||
| 	Fields                   []*Field | 	Fields                    []*Field | ||||||
| 	FieldsByName             map[string]*Field | 	FieldsByName              map[string]*Field | ||||||
| 	FieldsByDBName           map[string]*Field | 	FieldsByDBName            map[string]*Field | ||||||
| 	FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
 | 	FieldsWithDefaultDBValue  map[string]*Field // fields with default value assigned by database
 | ||||||
| 	Relationships            Relationships | 	Relationships             Relationships | ||||||
| 	err                      error | 	BeforeCreate, AfterCreate bool | ||||||
| 	namer                    Namer | 	BeforeUpdate, AfterUpdate bool | ||||||
| 	cacheStore               *sync.Map | 	BeforeDelete, AfterDelete bool | ||||||
|  | 	BeforeSave, AfterSave     bool | ||||||
|  | 	AfterFind                 bool | ||||||
|  | 	err                       error | ||||||
|  | 	namer                     Namer | ||||||
|  | 	cacheStore                *sync.Map | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (schema Schema) String() string { | func (schema Schema) String() string { | ||||||
| @ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} | ||||||
|  | 	for _, name := range callbacks { | ||||||
|  | 		if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { | ||||||
|  | 			switch methodValue.Type().String() { | ||||||
|  | 			case "func(*gorm.DB)": // TODO hack
 | ||||||
|  | 				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) | ||||||
|  | 			default: | ||||||
|  | 				logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	cacheStore.Store(modelType, schema) | 	cacheStore.Store(modelType, schema) | ||||||
| 
 | 
 | ||||||
| 	// parse relations for unidentified fields
 | 	// parse relations for unidentified fields
 | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestParseSchema(t *testing.T) { | func TestParseSchema(t *testing.T) { | ||||||
| 	user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) | 	user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to parse user, got error %v", err) | 		t.Fatalf("failed to parse user, got error %v", err) | ||||||
| 	} | 	} | ||||||
| @ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestParseSchemaWithPointerFields(t *testing.T) { | func TestParseSchemaWithPointerFields(t *testing.T) { | ||||||
| 	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) | 	user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to parse pointer user, got error %v", err) | 		t.Fatalf("failed to parse pointer user, got error %v", err) | ||||||
| 	} | 	} | ||||||
| @ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestParseSchemaWithAdvancedDataType(t *testing.T) { | func TestParseSchemaWithAdvancedDataType(t *testing.T) { | ||||||
| 	user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) | 	user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to parse pointer user, got error %v", err) | 		t.Fatalf("failed to parse pointer user, got error %v", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu