Refactor if logic (#4683)
* adjust code for preload * adjust code for Create
This commit is contained in:
		
							parent
							
								
									c170af11e9
								
							
						
					
					
						commit
						4c8810a848
					
				| @ -65,66 +65,81 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 			db.Statement.Build(db.Statement.BuildClauses...) | 			db.Statement.Build(db.Statement.BuildClauses...) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if !db.DryRun && db.Error == nil { | 		isDryRun := !db.DryRun && db.Error == nil | ||||||
|  | 		if !isDryRun { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | 		ok, mode := hasReturning(db, supportReturning) | ||||||
| 				if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { | 		if ok { | ||||||
| 					if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { | 			if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { | ||||||
| 						mode |= gorm.ScanOnConflictDoNothing | 				if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { | ||||||
| 					} | 					mode |= gorm.ScanOnConflictDoNothing | ||||||
| 				} | 				} | ||||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | 			} | ||||||
| 					gorm.Scan(rows, db, mode) |  | ||||||
| 					rows.Close() |  | ||||||
| 				} |  | ||||||
| 			} else { |  | ||||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |  | ||||||
| 
 | 
 | ||||||
| 				if err != nil { | 			rows, err := db.Statement.ConnPool.QueryContext( | ||||||
| 					db.AddError(err) | 				db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., | ||||||
| 					return | 			) | ||||||
| 				} | 			if db.AddError(err) == nil { | ||||||
|  | 				gorm.Scan(rows, db, mode) | ||||||
|  | 				rows.Close() | ||||||
|  | 			} | ||||||
| 
 | 
 | ||||||
| 				db.RowsAffected, _ = result.RowsAffected() | 			return | ||||||
| 				if db.RowsAffected != 0 && db.Statement.Schema != nil && | 		} | ||||||
| 					db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { |  | ||||||
| 					if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { |  | ||||||
| 						switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 						case reflect.Slice, reflect.Array: |  | ||||||
| 							if config.LastInsertIDReversed { |  | ||||||
| 								for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { |  | ||||||
| 									rv := db.Statement.ReflectValue.Index(i) |  | ||||||
| 									if reflect.Indirect(rv).Kind() != reflect.Struct { |  | ||||||
| 										break |  | ||||||
| 									} |  | ||||||
| 
 | 
 | ||||||
| 									_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) | 		result, err := db.Statement.ConnPool.ExecContext( | ||||||
| 									if isZero { | 			db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., | ||||||
| 										db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | 		) | ||||||
| 										insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | 		if err != nil { | ||||||
| 									} | 			db.AddError(err) | ||||||
| 								} | 			return | ||||||
| 							} else { | 		} | ||||||
| 								for i := 0; i < db.Statement.ReflectValue.Len(); i++ { |  | ||||||
| 									rv := db.Statement.ReflectValue.Index(i) |  | ||||||
| 									if reflect.Indirect(rv).Kind() != reflect.Struct { |  | ||||||
| 										break |  | ||||||
| 									} |  | ||||||
| 
 | 
 | ||||||
| 									if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { | 		db.RowsAffected, _ = result.RowsAffected() | ||||||
| 										db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | 		if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||||
| 										insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | 			db.Statement.Schema.PrioritizedPrimaryField != nil && | ||||||
| 									} | 			db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||||
| 								} | 			insertID, err := result.LastInsertId() | ||||||
| 							} | 			insertOk := err == nil && insertID > 0 | ||||||
| 						case reflect.Struct: | 			if !insertOk { | ||||||
| 							if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { | 				db.AddError(err) | ||||||
| 								db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | 				return | ||||||
| 							} | 			} | ||||||
|  | 
 | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				if config.LastInsertIDReversed { | ||||||
|  | 					for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||||
|  | 						rv := db.Statement.ReflectValue.Index(i) | ||||||
|  | 						if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||||
|  | 							break | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) | ||||||
|  | 						if isZero { | ||||||
|  | 							db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||||
|  | 							insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||||
| 						} | 						} | ||||||
| 					} else { |  | ||||||
| 						db.AddError(err) |  | ||||||
| 					} | 					} | ||||||
|  | 				} else { | ||||||
|  | 					for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
|  | 						rv := db.Statement.ReflectValue.Index(i) | ||||||
|  | 						if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||||
|  | 							break | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { | ||||||
|  | 							db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||||
|  | 							insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) | ||||||
|  | 				if isZero { | ||||||
|  | 					db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -26,82 +26,87 @@ func BeforeDelete(db *gorm.DB) { | |||||||
| func DeleteBeforeAssociations(db *gorm.DB) { | func DeleteBeforeAssociations(db *gorm.DB) { | ||||||
| 	if db.Error == nil && db.Statement.Schema != nil { | 	if db.Error == nil && db.Statement.Schema != nil { | ||||||
| 		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) | 		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) | ||||||
|  | 		if !restricted { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		if restricted { | 		for column, v := range selectColumns { | ||||||
| 			for column, v := range selectColumns { | 			if !v { | ||||||
| 				if v { | 				continue | ||||||
| 					if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { | 			} | ||||||
| 						switch rel.Type { |  | ||||||
| 						case schema.HasOne, schema.HasMany: |  | ||||||
| 							queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) |  | ||||||
| 							modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() |  | ||||||
| 							tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) |  | ||||||
| 							withoutConditions := false |  | ||||||
| 							if db.Statement.Unscoped { |  | ||||||
| 								tx = tx.Unscoped() |  | ||||||
| 							} |  | ||||||
| 
 | 
 | ||||||
| 							if len(db.Statement.Selects) > 0 { | 			rel, ok := db.Statement.Schema.Relationships.Relations[column] | ||||||
| 								selects := make([]string, 0, len(db.Statement.Selects)) | 			if !ok { | ||||||
| 								for _, s := range db.Statement.Selects { | 				continue | ||||||
| 									if s == clause.Associations { | 			} | ||||||
| 										selects = append(selects, s) |  | ||||||
| 									} else if strings.HasPrefix(s, column+".") { |  | ||||||
| 										selects = append(selects, strings.TrimPrefix(s, column+".")) |  | ||||||
| 									} |  | ||||||
| 								} |  | ||||||
| 
 | 
 | ||||||
| 								if len(selects) > 0 { | 			switch rel.Type { | ||||||
| 									tx = tx.Select(selects) | 			case schema.HasOne, schema.HasMany: | ||||||
| 								} | 				queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) | ||||||
| 							} | 				modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | ||||||
|  | 				tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) | ||||||
|  | 				withoutConditions := false | ||||||
|  | 				if db.Statement.Unscoped { | ||||||
|  | 					tx = tx.Unscoped() | ||||||
|  | 				} | ||||||
| 
 | 
 | ||||||
| 							for _, cond := range queryConds { | 				if len(db.Statement.Selects) > 0 { | ||||||
| 								if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { | 					selects := make([]string, 0, len(db.Statement.Selects)) | ||||||
| 									withoutConditions = true | 					for _, s := range db.Statement.Selects { | ||||||
| 									break | 						if s == clause.Associations { | ||||||
| 								} | 							selects = append(selects, s) | ||||||
| 							} | 						} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { | ||||||
| 
 | 							selects = append(selects, strings.TrimPrefix(s, columnPrefix)) | ||||||
| 							if !withoutConditions { |  | ||||||
| 								if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { |  | ||||||
| 									return |  | ||||||
| 								} |  | ||||||
| 							} |  | ||||||
| 						case schema.Many2Many: |  | ||||||
| 							var ( |  | ||||||
| 								queryConds     = make([]clause.Expression, 0, len(rel.References)) |  | ||||||
| 								foreignFields  = make([]*schema.Field, 0, len(rel.References)) |  | ||||||
| 								relForeignKeys = make([]string, 0, len(rel.References)) |  | ||||||
| 								modelValue     = reflect.New(rel.JoinTable.ModelType).Interface() |  | ||||||
| 								table          = rel.JoinTable.Table |  | ||||||
| 								tx             = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) |  | ||||||
| 							) |  | ||||||
| 
 |  | ||||||
| 							for _, ref := range rel.References { |  | ||||||
| 								if ref.OwnPrimaryKey { |  | ||||||
| 									foreignFields = append(foreignFields, ref.PrimaryKey) |  | ||||||
| 									relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) |  | ||||||
| 								} else if ref.PrimaryValue != "" { |  | ||||||
| 									queryConds = append(queryConds, clause.Eq{ |  | ||||||
| 										Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, |  | ||||||
| 										Value:  ref.PrimaryValue, |  | ||||||
| 									}) |  | ||||||
| 								} |  | ||||||
| 							} |  | ||||||
| 
 |  | ||||||
| 							_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) |  | ||||||
| 							column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) |  | ||||||
| 							queryConds = append(queryConds, clause.IN{Column: column, Values: values}) |  | ||||||
| 
 |  | ||||||
| 							if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { |  | ||||||
| 								return |  | ||||||
| 							} |  | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
|  | 
 | ||||||
|  | 					if len(selects) > 0 { | ||||||
|  | 						tx = tx.Select(selects) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				for _, cond := range queryConds { | ||||||
|  | 					if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { | ||||||
|  | 						withoutConditions = true | ||||||
|  | 						break | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 			case schema.Many2Many: | ||||||
|  | 				var ( | ||||||
|  | 					queryConds     = make([]clause.Expression, 0, len(rel.References)) | ||||||
|  | 					foreignFields  = make([]*schema.Field, 0, len(rel.References)) | ||||||
|  | 					relForeignKeys = make([]string, 0, len(rel.References)) | ||||||
|  | 					modelValue     = reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
|  | 					table          = rel.JoinTable.Table | ||||||
|  | 					tx             = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) | ||||||
|  | 				) | ||||||
|  | 
 | ||||||
|  | 				for _, ref := range rel.References { | ||||||
|  | 					if ref.OwnPrimaryKey { | ||||||
|  | 						foreignFields = append(foreignFields, ref.PrimaryKey) | ||||||
|  | 						relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) | ||||||
|  | 					} else if ref.PrimaryValue != "" { | ||||||
|  | 						queryConds = append(queryConds, clause.Eq{ | ||||||
|  | 							Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, | ||||||
|  | 							Value:  ref.PrimaryValue, | ||||||
|  | 						}) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) | ||||||
|  | 				column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) | ||||||
|  | 				queryConds = append(queryConds, clause.IN{Column: column, Values: values}) | ||||||
|  | 
 | ||||||
|  | 				if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||||
|  | 					return | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -145,27 +145,30 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | |||||||
| 			fieldValues[idx], _ = field.ValueOf(elem) | 			fieldValues[idx], _ = field.ValueOf(elem) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok { | 		datas, ok := identityMap[utils.ToStringKey(fieldValues...)] | ||||||
| 			for _, data := range datas { | 		if !ok { | ||||||
| 				reflectFieldValue := rel.Field.ReflectValueOf(data) | 			db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", | ||||||
| 				if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { | 				elem.Interface())) | ||||||
| 					reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) | 			continue | ||||||
| 				} | 		} | ||||||
| 
 | 
 | ||||||
| 				reflectFieldValue = reflect.Indirect(reflectFieldValue) | 		for _, data := range datas { | ||||||
| 				switch reflectFieldValue.Kind() { | 			reflectFieldValue := rel.Field.ReflectValueOf(data) | ||||||
| 				case reflect.Struct: | 			if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { | ||||||
| 					rel.Field.Set(data, reflectResults.Index(i).Interface()) | 				reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) | ||||||
| 				case reflect.Slice, reflect.Array: | 			} | ||||||
| 					if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { | 
 | ||||||
| 						rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) | 			reflectFieldValue = reflect.Indirect(reflectFieldValue) | ||||||
| 					} else { | 			switch reflectFieldValue.Kind() { | ||||||
| 						rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) | 			case reflect.Struct: | ||||||
| 					} | 				rel.Field.Set(data, elem.Interface()) | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { | ||||||
|  | 					rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) | ||||||
|  | 				} else { | ||||||
|  | 					rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} else { |  | ||||||
| 			db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())) |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 heige
						heige