Add permission check when create associations
This commit is contained in:
		
							parent
							
								
									345ff7577c
								
							
						
					
					
						commit
						56ca9a87e0
					
				
							
								
								
									
										72
									
								
								callbacks/associations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								callbacks/associations.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,72 @@ | |||||||
|  | package callbacks | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
|  | 	"github.com/jinzhu/gorm/utils" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func SaveBeforeAssociations(db *gorm.DB) { | ||||||
|  | 	if db.Statement.Schema != nil { | ||||||
|  | 		for _, rel := range db.Statement.Schema.Relationships.BelongsTo { | ||||||
|  | 			creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) | ||||||
|  | 
 | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice: | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { | ||||||
|  | 					f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) | ||||||
|  | 
 | ||||||
|  | 					_, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) | ||||||
|  | 
 | ||||||
|  | 					if isZero && creatable { | ||||||
|  | 						if f.Kind() == reflect.Ptr { | ||||||
|  | 							db.Session(&gorm.Session{}).Create(f.Interface()) | ||||||
|  | 						} else { | ||||||
|  | 							db.Session(&gorm.Session{}).Create(f.Addr().Interface()) | ||||||
|  | 						} | ||||||
|  | 					} else if !isZero && updatable { | ||||||
|  | 						if f.Kind() == reflect.Ptr { | ||||||
|  | 							db.Session(&gorm.Session{}).Save(f.Interface()) | ||||||
|  | 						} else { | ||||||
|  | 							db.Session(&gorm.Session{}).Save(f.Addr().Interface()) | ||||||
|  | 						} | ||||||
|  | 					} else { | ||||||
|  | 						continue | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					if saveRef { | ||||||
|  | 						for _, ref := range rel.References { | ||||||
|  | 							if !ref.OwnPrimaryKey { | ||||||
|  | 								fv, _ := ref.PrimaryKey.ValueOf(f) | ||||||
|  | 								ref.ForeignKey.Set(db.Statement.ReflectValue, fv) | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { | ||||||
|  | 	creatable := field.Creatable | ||||||
|  | 	updatable := field.Updatable | ||||||
|  | 	saveRef := true | ||||||
|  | 
 | ||||||
|  | 	if value, ok := db.Get("gorm:association_autocreate"); creatable && ok { | ||||||
|  | 		creatable = utils.CheckTruth(value) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok { | ||||||
|  | 		updatable = utils.CheckTruth(value) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if value, ok := db.Get("gorm:association_save_reference"); ok { | ||||||
|  | 		saveRef = utils.CheckTruth(value) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return creatable, updatable, saveRef | ||||||
|  | } | ||||||
| @ -41,32 +41,6 @@ func BeforeCreate(db *gorm.DB) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func SaveBeforeAssociations(db *gorm.DB) { |  | ||||||
| 	if db.Statement.Schema != nil { |  | ||||||
| 		for _, rel := range db.Statement.Schema.Relationships.BelongsTo { |  | ||||||
| 			switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 			case reflect.Slice: |  | ||||||
| 			case reflect.Struct: |  | ||||||
| 				if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { |  | ||||||
| 					f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) |  | ||||||
| 					if f.Kind() == reflect.Ptr { |  | ||||||
| 						db.Session(&gorm.Session{}).Create(f.Interface()) |  | ||||||
| 					} else { |  | ||||||
| 						db.Session(&gorm.Session{}).Create(f.Addr().Interface()) |  | ||||||
| 					} |  | ||||||
| 
 |  | ||||||
| 					for _, ref := range rel.References { |  | ||||||
| 						if !ref.OwnPrimaryKey { |  | ||||||
| 							fv, _ := ref.PrimaryKey.ValueOf(f) |  | ||||||
| 							ref.ForeignKey.Set(db.Statement.ReflectValue, fv) |  | ||||||
| 						} |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Create(config *Config) func(db *gorm.DB) { | func Create(config *Config) func(db *gorm.DB) { | ||||||
| 	if config.WithReturning { | 	if config.WithReturning { | ||||||
| 		return CreateWithReturning | 		return CreateWithReturning | ||||||
|  | |||||||
| @ -21,7 +21,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = value | 	tx.Statement.Dest = value | ||||||
| 
 | 
 | ||||||
| 	if err := tx.Statement.Parse(value); err != nil && tx.Statement.Schema != nil { | 	if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | ||||||
| 		where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | 		where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | ||||||
| 		reflectValue := reflect.ValueOf(value) | 		reflectValue := reflect.ValueOf(value) | ||||||
| 		for idx, pf := range tx.Statement.Schema.PrimaryFields { | 		for idx, pf := range tx.Statement.Schema.PrimaryFields { | ||||||
| @ -35,9 +35,6 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| 		tx.Statement.AddClause(where) | 		tx.Statement.AddClause(where) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(tx.Statement.Selects) == 0 { |  | ||||||
| 		tx.Statement.Selects = []string{"*"} |  | ||||||
| 	} |  | ||||||
| 	tx.callbacks.Update().Execute(tx) | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | |||||||
| @ -10,6 +10,7 @@ import ( | |||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/jinzhu/gorm/utils" | ||||||
| 	"github.com/jinzhu/now" | 	"github.com/jinzhu/now" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -146,13 +147,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 		field.DBName = dbName | 		field.DBName = dbName | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) { | 	if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { | ||||||
| 		field.PrimaryKey = true | 		field.PrimaryKey = true | ||||||
| 	} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { | 	} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { | ||||||
| 		field.PrimaryKey = true | 		field.PrimaryKey = true | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) { | 	if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { | ||||||
| 		field.AutoIncrement = true | 		field.AutoIncrement = true | ||||||
| 		field.HasDefaultValue = true | 		field.HasDefaultValue = true | ||||||
| 	} | 	} | ||||||
| @ -173,11 +174,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 		field.Precision, _ = strconv.Atoi(p) | 		field.Precision, _ = strconv.Atoi(p) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) { | 	if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { | ||||||
| 		field.NotNull = true | 		field.NotNull = true | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) { | 	if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { | ||||||
| 		field.Unique = true | 		field.Unique = true | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -37,13 +37,6 @@ func ParseTagSetting(str string, sep string) map[string]string { | |||||||
| 	return settings | 	return settings | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func checkTruth(val string) bool { |  | ||||||
| 	if strings.ToLower(val) == "false" { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	return true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func toColumns(val string) (results []string) { | func toColumns(val string) (results []string) { | ||||||
| 	if val != "" { | 	if val != "" { | ||||||
| 		for _, v := range strings.Split(val, ",") { | 		for _, v := range strings.Split(val, ",") { | ||||||
|  | |||||||
| @ -2,8 +2,10 @@ package utils | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"reflect" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"runtime" | 	"runtime" | ||||||
|  | 	"strings" | ||||||
| 	"unicode" | 	"unicode" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -23,3 +25,16 @@ func FileWithLineNum() string { | |||||||
| func IsChar(c rune) bool { | func IsChar(c rune) bool { | ||||||
| 	return !unicode.IsLetter(c) && !unicode.IsNumber(c) | 	return !unicode.IsLetter(c) && !unicode.IsNumber(c) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func CheckTruth(val interface{}) bool { | ||||||
|  | 	if v, ok := val.(bool); ok { | ||||||
|  | 		return v | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if v, ok := val.(string); ok { | ||||||
|  | 		v = strings.ToLower(v) | ||||||
|  | 		return v != "false" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return !reflect.ValueOf(val).IsZero() | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu