Refactor create callback
This commit is contained in:
		
							parent
							
								
									e38b1e0948
								
							
						
					
					
						commit
						92213273a5
					
				@ -42,30 +42,29 @@ func createCallback(scope *Scope) {
 | 
				
			|||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		defer scope.trace(NowFunc())
 | 
							defer scope.trace(NowFunc())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// set create sql
 | 
							var (
 | 
				
			||||||
		var sqls, columns []string
 | 
								columns, placeholders        []string
 | 
				
			||||||
		fields := scope.Fields()
 | 
								blankColumnsWithDefaultValue []string
 | 
				
			||||||
 | 
								fields                       = scope.Fields()
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for _, field := range fields {
 | 
							for _, field := range fields {
 | 
				
			||||||
			if scope.changeableField(field) {
 | 
								if scope.changeableField(field) {
 | 
				
			||||||
				if field.IsNormal {
 | 
									if field.IsNormal {
 | 
				
			||||||
					if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
 | 
										if !field.IsPrimaryKey || !field.IsBlank {
 | 
				
			||||||
						if !field.IsBlank || !field.HasDefaultValue {
 | 
											if field.IsBlank && field.HasDefaultValue {
 | 
				
			||||||
 | 
												blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
 | 
				
			||||||
 | 
												scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
 | 
				
			||||||
 | 
											} else {
 | 
				
			||||||
							columns = append(columns, scope.Quote(field.DBName))
 | 
												columns = append(columns, scope.Quote(field.DBName))
 | 
				
			||||||
							sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
 | 
												placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
 | 
				
			||||||
						} else if field.HasDefaultValue {
 | 
					 | 
				
			||||||
							var hasDefaultValueColumns []string
 | 
					 | 
				
			||||||
							if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
 | 
					 | 
				
			||||||
								hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
							hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
 | 
					 | 
				
			||||||
							scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
									} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
 | 
				
			||||||
					for _, dbName := range relationship.ForeignDBNames {
 | 
										for _, foreignKey := range field.Relationship.ForeignDBNames {
 | 
				
			||||||
						if relationField := fields[dbName]; !scope.changeableField(relationField) {
 | 
											if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) {
 | 
				
			||||||
							columns = append(columns, scope.Quote(relationField.DBName))
 | 
												columns = append(columns, scope.Quote(foreignField.DBName))
 | 
				
			||||||
							sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
 | 
												placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -88,35 +87,27 @@ func createCallback(scope *Scope) {
 | 
				
			|||||||
				"INSERT INTO %v (%v) VALUES (%v) %v",
 | 
									"INSERT INTO %v (%v) VALUES (%v) %v",
 | 
				
			||||||
				scope.QuotedTableName(),
 | 
									scope.QuotedTableName(),
 | 
				
			||||||
				strings.Join(columns, ","),
 | 
									strings.Join(columns, ","),
 | 
				
			||||||
				strings.Join(sqls, ","),
 | 
									strings.Join(placeholders, ","),
 | 
				
			||||||
				scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
 | 
									scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
 | 
				
			||||||
			))
 | 
								))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// execute create sql
 | 
							// execute create sql
 | 
				
			||||||
		if scope.Dialect().SupportLastInsertId() {
 | 
							if scope.Dialect().SupportLastInsertId() || primaryField == nil {
 | 
				
			||||||
			if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
 | 
								if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
 | 
				
			||||||
				id, err := result.LastInsertId()
 | 
									// set rows affected count
 | 
				
			||||||
				if scope.Err(err) == nil {
 | 
									scope.db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
					scope.db.RowsAffected, _ = result.RowsAffected()
 | 
					
 | 
				
			||||||
					if primaryField != nil && primaryField.IsBlank {
 | 
									// set primary value to primary field
 | 
				
			||||||
						scope.Err(scope.SetColumn(primaryField, id))
 | 
									if primaryField != nil && primaryField.IsBlank {
 | 
				
			||||||
 | 
										if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
 | 
				
			||||||
 | 
											scope.Err(primaryField.Set(primaryValue))
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if primaryField == nil {
 | 
								if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
 | 
				
			||||||
				if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
 | 
									scope.db.RowsAffected = 1
 | 
				
			||||||
					scope.db.RowsAffected, _ = results.RowsAffected()
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					scope.Err(err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
 | 
					 | 
				
			||||||
					scope.db.RowsAffected = 1
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					scope.Err(err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -124,8 +115,8 @@ func createCallback(scope *Scope) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
 | 
					// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
 | 
				
			||||||
func forceReloadAfterCreateCallback(scope *Scope) {
 | 
					func forceReloadAfterCreateCallback(scope *Scope) {
 | 
				
			||||||
	if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
 | 
						if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
 | 
				
			||||||
		scope.DB().New().Select(columns.([]string)).First(scope.Value)
 | 
							scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										28
									
								
								field.go
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								field.go
									
									
									
									
									
								
							@ -58,15 +58,20 @@ func (field *Field) Set(value interface{}) (err error) {
 | 
				
			|||||||
// Fields get value's fields
 | 
					// Fields get value's fields
 | 
				
			||||||
func (scope *Scope) Fields() map[string]*Field {
 | 
					func (scope *Scope) Fields() map[string]*Field {
 | 
				
			||||||
	if scope.fields == nil {
 | 
						if scope.fields == nil {
 | 
				
			||||||
		fields := map[string]*Field{}
 | 
							var (
 | 
				
			||||||
		modelStruct := scope.GetModelStruct()
 | 
								fields             = map[string]*Field{}
 | 
				
			||||||
 | 
								indirectScopeValue = scope.IndirectValue()
 | 
				
			||||||
 | 
								isStruct           = indirectScopeValue.Kind() == reflect.Struct
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		indirectValue := scope.IndirectValue()
 | 
							for _, structField := range scope.GetModelStruct().StructFields {
 | 
				
			||||||
		isStruct := indirectValue.Kind() == reflect.Struct
 | 
					 | 
				
			||||||
		for _, structField := range modelStruct.StructFields {
 | 
					 | 
				
			||||||
			if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
 | 
								if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
 | 
				
			||||||
				if isStruct {
 | 
									if isStruct {
 | 
				
			||||||
					fields[structField.DBName] = getField(indirectValue, structField)
 | 
										fieldValue := indirectScopeValue
 | 
				
			||||||
 | 
										for _, name := range structField.Names {
 | 
				
			||||||
 | 
											fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
 | 
										fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -74,17 +79,6 @@ func (scope *Scope) Fields() map[string]*Field {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		scope.fields = fields
 | 
							scope.fields = fields
 | 
				
			||||||
		return fields
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return scope.fields
 | 
						return scope.fields
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func getField(indirectValue reflect.Value, structField *StructField) *Field {
 | 
					 | 
				
			||||||
	field := &Field{StructField: structField}
 | 
					 | 
				
			||||||
	for _, name := range structField.Names {
 | 
					 | 
				
			||||||
		indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	field.Field = indirectValue
 | 
					 | 
				
			||||||
	field.IsBlank = isBlank(indirectValue)
 | 
					 | 
				
			||||||
	return field
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										10
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								utils.go
									
									
									
									
									
								
							@ -132,11 +132,11 @@ func toQueryCondition(scope *Scope, columns []string) string {
 | 
				
			|||||||
	return strings.Join(newColumns, ",")
 | 
						return strings.Join(newColumns, ",")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
 | 
					func toQueryValues(values [][]interface{}) (results []interface{}) {
 | 
				
			||||||
	for _, primaryValue := range primaryValues {
 | 
						for _, value := range values {
 | 
				
			||||||
		for _, value := range primaryValue {
 | 
							for _, v := range value {
 | 
				
			||||||
			values = append(values, value)
 | 
								results = append(results, v)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return values
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user