Refactor if logic (#4683)
* adjust code for preload * adjust code for Create
This commit is contained in:
		
							parent
							
								
									c170af11e9
								
							
						
					
					
						commit
						4c8810a848
					
				| @ -65,66 +65,81 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 			db.Statement.Build(db.Statement.BuildClauses...) | ||||
| 		} | ||||
| 
 | ||||
| 		if !db.DryRun && db.Error == nil { | ||||
| 		isDryRun := !db.DryRun && db.Error == nil | ||||
| 		if !isDryRun { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | ||||
| 				if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { | ||||
| 					if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { | ||||
| 						mode |= gorm.ScanOnConflictDoNothing | ||||
| 					} | ||||
| 		ok, mode := hasReturning(db, supportReturning) | ||||
| 		if ok { | ||||
| 			if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { | ||||
| 				if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { | ||||
| 					mode |= gorm.ScanOnConflictDoNothing | ||||
| 				} | ||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 					gorm.Scan(rows, db, mode) | ||||
| 					rows.Close() | ||||
| 				} | ||||
| 			} else { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 			} | ||||
| 
 | ||||
| 				if err != nil { | ||||
| 					db.AddError(err) | ||||
| 					return | ||||
| 				} | ||||
| 			rows, err := db.Statement.ConnPool.QueryContext( | ||||
| 				db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., | ||||
| 			) | ||||
| 			if db.AddError(err) == nil { | ||||
| 				gorm.Scan(rows, db, mode) | ||||
| 				rows.Close() | ||||
| 			} | ||||
| 
 | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 				if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||
| 					db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 					if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { | ||||
| 						switch db.Statement.ReflectValue.Kind() { | ||||
| 						case reflect.Slice, reflect.Array: | ||||
| 							if config.LastInsertIDReversed { | ||||
| 								for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||
| 									rv := db.Statement.ReflectValue.Index(i) | ||||
| 									if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||
| 										break | ||||
| 									} | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 									_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) | ||||
| 									if isZero { | ||||
| 										db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||
| 										insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 									} | ||||
| 								} | ||||
| 							} else { | ||||
| 								for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 									rv := db.Statement.ReflectValue.Index(i) | ||||
| 									if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||
| 										break | ||||
| 									} | ||||
| 		result, err := db.Statement.ConnPool.ExecContext( | ||||
| 			db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			db.AddError(err) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 									if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { | ||||
| 										db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||
| 										insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| 						case reflect.Struct: | ||||
| 							if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { | ||||
| 								db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||
| 							} | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 		if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||
| 			db.Statement.Schema.PrioritizedPrimaryField != nil && | ||||
| 			db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 			insertID, err := result.LastInsertId() | ||||
| 			insertOk := err == nil && insertID > 0 | ||||
| 			if !insertOk { | ||||
| 				db.AddError(err) | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				if config.LastInsertIDReversed { | ||||
| 					for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||
| 						rv := db.Statement.ReflectValue.Index(i) | ||||
| 						if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||
| 							break | ||||
| 						} | ||||
| 
 | ||||
| 						_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) | ||||
| 						if isZero { | ||||
| 							db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||
| 							insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 						} | ||||
| 					} else { | ||||
| 						db.AddError(err) | ||||
| 					} | ||||
| 				} else { | ||||
| 					for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 						rv := db.Statement.ReflectValue.Index(i) | ||||
| 						if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||
| 							break | ||||
| 						} | ||||
| 
 | ||||
| 						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { | ||||
| 							db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||
| 							insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			case reflect.Struct: | ||||
| 				_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) | ||||
| 				if isZero { | ||||
| 					db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -26,82 +26,87 @@ func BeforeDelete(db *gorm.DB) { | ||||
| func DeleteBeforeAssociations(db *gorm.DB) { | ||||
| 	if db.Error == nil && db.Statement.Schema != nil { | ||||
| 		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) | ||||
| 		if !restricted { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if restricted { | ||||
| 			for column, v := range selectColumns { | ||||
| 				if v { | ||||
| 					if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { | ||||
| 						switch rel.Type { | ||||
| 						case schema.HasOne, schema.HasMany: | ||||
| 							queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) | ||||
| 							modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | ||||
| 							tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) | ||||
| 							withoutConditions := false | ||||
| 							if db.Statement.Unscoped { | ||||
| 								tx = tx.Unscoped() | ||||
| 							} | ||||
| 		for column, v := range selectColumns { | ||||
| 			if !v { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 							if len(db.Statement.Selects) > 0 { | ||||
| 								selects := make([]string, 0, len(db.Statement.Selects)) | ||||
| 								for _, s := range db.Statement.Selects { | ||||
| 									if s == clause.Associations { | ||||
| 										selects = append(selects, s) | ||||
| 									} else if strings.HasPrefix(s, column+".") { | ||||
| 										selects = append(selects, strings.TrimPrefix(s, column+".")) | ||||
| 									} | ||||
| 								} | ||||
| 			rel, ok := db.Statement.Schema.Relationships.Relations[column] | ||||
| 			if !ok { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 								if len(selects) > 0 { | ||||
| 									tx = tx.Select(selects) | ||||
| 								} | ||||
| 							} | ||||
| 			switch rel.Type { | ||||
| 			case schema.HasOne, schema.HasMany: | ||||
| 				queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) | ||||
| 				modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | ||||
| 				tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) | ||||
| 				withoutConditions := false | ||||
| 				if db.Statement.Unscoped { | ||||
| 					tx = tx.Unscoped() | ||||
| 				} | ||||
| 
 | ||||
| 							for _, cond := range queryConds { | ||||
| 								if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { | ||||
| 									withoutConditions = true | ||||
| 									break | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							if !withoutConditions { | ||||
| 								if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||
| 									return | ||||
| 								} | ||||
| 							} | ||||
| 						case schema.Many2Many: | ||||
| 							var ( | ||||
| 								queryConds     = make([]clause.Expression, 0, len(rel.References)) | ||||
| 								foreignFields  = make([]*schema.Field, 0, len(rel.References)) | ||||
| 								relForeignKeys = make([]string, 0, len(rel.References)) | ||||
| 								modelValue     = reflect.New(rel.JoinTable.ModelType).Interface() | ||||
| 								table          = rel.JoinTable.Table | ||||
| 								tx             = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) | ||||
| 							) | ||||
| 
 | ||||
| 							for _, ref := range rel.References { | ||||
| 								if ref.OwnPrimaryKey { | ||||
| 									foreignFields = append(foreignFields, ref.PrimaryKey) | ||||
| 									relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) | ||||
| 								} else if ref.PrimaryValue != "" { | ||||
| 									queryConds = append(queryConds, clause.Eq{ | ||||
| 										Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, | ||||
| 										Value:  ref.PrimaryValue, | ||||
| 									}) | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) | ||||
| 							column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) | ||||
| 							queryConds = append(queryConds, clause.IN{Column: column, Values: values}) | ||||
| 
 | ||||
| 							if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||
| 								return | ||||
| 							} | ||||
| 				if len(db.Statement.Selects) > 0 { | ||||
| 					selects := make([]string, 0, len(db.Statement.Selects)) | ||||
| 					for _, s := range db.Statement.Selects { | ||||
| 						if s == clause.Associations { | ||||
| 							selects = append(selects, s) | ||||
| 						} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { | ||||
| 							selects = append(selects, strings.TrimPrefix(s, columnPrefix)) | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					if len(selects) > 0 { | ||||
| 						tx = tx.Select(selects) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				for _, cond := range queryConds { | ||||
| 					if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { | ||||
| 						withoutConditions = true | ||||
| 						break | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||
| 					return | ||||
| 				} | ||||
| 			case schema.Many2Many: | ||||
| 				var ( | ||||
| 					queryConds     = make([]clause.Expression, 0, len(rel.References)) | ||||
| 					foreignFields  = make([]*schema.Field, 0, len(rel.References)) | ||||
| 					relForeignKeys = make([]string, 0, len(rel.References)) | ||||
| 					modelValue     = reflect.New(rel.JoinTable.ModelType).Interface() | ||||
| 					table          = rel.JoinTable.Table | ||||
| 					tx             = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) | ||||
| 				) | ||||
| 
 | ||||
| 				for _, ref := range rel.References { | ||||
| 					if ref.OwnPrimaryKey { | ||||
| 						foreignFields = append(foreignFields, ref.PrimaryKey) | ||||
| 						relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) | ||||
| 					} else if ref.PrimaryValue != "" { | ||||
| 						queryConds = append(queryConds, clause.Eq{ | ||||
| 							Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, | ||||
| 							Value:  ref.PrimaryValue, | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) | ||||
| 				column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) | ||||
| 				queryConds = append(queryConds, clause.IN{Column: column, Values: values}) | ||||
| 
 | ||||
| 				if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -145,27 +145,30 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			fieldValues[idx], _ = field.ValueOf(elem) | ||||
| 		} | ||||
| 
 | ||||
| 		if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok { | ||||
| 			for _, data := range datas { | ||||
| 				reflectFieldValue := rel.Field.ReflectValueOf(data) | ||||
| 				if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { | ||||
| 					reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) | ||||
| 				} | ||||
| 		datas, ok := identityMap[utils.ToStringKey(fieldValues...)] | ||||
| 		if !ok { | ||||
| 			db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", | ||||
| 				elem.Interface())) | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 				reflectFieldValue = reflect.Indirect(reflectFieldValue) | ||||
| 				switch reflectFieldValue.Kind() { | ||||
| 				case reflect.Struct: | ||||
| 					rel.Field.Set(data, reflectResults.Index(i).Interface()) | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { | ||||
| 						rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) | ||||
| 					} else { | ||||
| 						rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) | ||||
| 					} | ||||
| 		for _, data := range datas { | ||||
| 			reflectFieldValue := rel.Field.ReflectValueOf(data) | ||||
| 			if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { | ||||
| 				reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) | ||||
| 			} | ||||
| 
 | ||||
| 			reflectFieldValue = reflect.Indirect(reflectFieldValue) | ||||
| 			switch reflectFieldValue.Kind() { | ||||
| 			case reflect.Struct: | ||||
| 				rel.Field.Set(data, elem.Interface()) | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { | ||||
| 					rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) | ||||
| 				} else { | ||||
| 					rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 heige
						heige