Test Association For BelongsTo
This commit is contained in:
		
							parent
							
								
									cbc4a81140
								
							
						
					
					
						commit
						91a695893c
					
				| @ -19,8 +19,10 @@ type Association struct { | |||||||
| 
 | 
 | ||||||
| func (db *DB) Association(column string) *Association { | func (db *DB) Association(column string) *Association { | ||||||
| 	association := &Association{DB: db} | 	association := &Association{DB: db} | ||||||
|  | 	table := db.Statement.Table | ||||||
| 
 | 
 | ||||||
| 	if err := db.Statement.Parse(db.Statement.Model); err == nil { | 	if err := db.Statement.Parse(db.Statement.Model); err == nil { | ||||||
|  | 		db.Statement.Table = table | ||||||
| 		association.Relationship = db.Statement.Schema.Relationships.Relations[column] | 		association.Relationship = db.Statement.Schema.Relationships.Relations[column] | ||||||
| 
 | 
 | ||||||
| 		if association.Relationship == nil { | 		if association.Relationship == nil { | ||||||
| @ -83,6 +85,16 @@ func (association *Association) Replace(values ...interface{}) error { | |||||||
| 		rel := association.Relationship | 		rel := association.Relationship | ||||||
| 
 | 
 | ||||||
| 		switch rel.Type { | 		switch rel.Type { | ||||||
|  | 		case schema.BelongsTo: | ||||||
|  | 			if len(values) == 0 { | ||||||
|  | 				updateMap := map[string]interface{}{} | ||||||
|  | 
 | ||||||
|  | 				for _, ref := range rel.References { | ||||||
|  | 					updateMap[ref.ForeignKey.DBName] = nil | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				association.DB.UpdateColumns(updateMap) | ||||||
|  | 			} | ||||||
| 		case schema.HasOne, schema.HasMany: | 		case schema.HasOne, schema.HasMany: | ||||||
| 			var ( | 			var ( | ||||||
| 				primaryFields []*schema.Field | 				primaryFields []*schema.Field | ||||||
| @ -90,6 +102,9 @@ func (association *Association) Replace(values ...interface{}) error { | |||||||
| 				updateMap     = map[string]interface{}{} | 				updateMap     = map[string]interface{}{} | ||||||
| 				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface() | 				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface() | ||||||
| 			) | 			) | ||||||
|  | 			if rel.Type == schema.BelongsTo { | ||||||
|  | 				modelValue = reflect.New(rel.Schema.ModelType).Interface() | ||||||
|  | 			} | ||||||
| 
 | 
 | ||||||
| 			for _, ref := range rel.References { | 			for _, ref := range rel.References { | ||||||
| 				if ref.OwnPrimaryKey { | 				if ref.OwnPrimaryKey { | ||||||
| @ -101,7 +116,7 @@ func (association *Association) Replace(values ...interface{}) error { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) | 			_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) | ||||||
| 			if len(values) > 0 { | 			if len(values) == 0 { | ||||||
| 				column, queryValues := schema.ToQueryValues(foreignKeys, values) | 				column, queryValues := schema.ToQueryValues(foreignKeys, values) | ||||||
| 				association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) | 				association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) | ||||||
| 			} | 			} | ||||||
| @ -158,13 +173,13 @@ func (association *Association) Replace(values ...interface{}) error { | |||||||
| func (association *Association) Delete(values ...interface{}) error { | func (association *Association) Delete(values ...interface{}) error { | ||||||
| 	if association.Error == nil { | 	if association.Error == nil { | ||||||
| 		var ( | 		var ( | ||||||
| 			tx           = association.DB | 			tx               = association.DB | ||||||
| 			rel          = association.Relationship | 			rel              = association.Relationship | ||||||
| 			reflectValue = tx.Statement.ReflectValue | 			reflectValue     = tx.Statement.ReflectValue | ||||||
| 			conds        = rel.ToQueryConditions(reflectValue) | 			relFields        []*schema.Field | ||||||
| 			relFields    []*schema.Field | 			foreignKeyFields []*schema.Field | ||||||
| 			foreignKeys  []string | 			foreignKeys      []string | ||||||
| 			updateAttrs  = map[string]interface{}{} | 			updateAttrs      = map[string]interface{}{} | ||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		for _, ref := range rel.References { | 		for _, ref := range rel.References { | ||||||
| @ -174,6 +189,7 @@ func (association *Association) Delete(values ...interface{}) error { | |||||||
| 						relFields = append(relFields, ref.ForeignKey) | 						relFields = append(relFields, ref.ForeignKey) | ||||||
| 					} else { | 					} else { | ||||||
| 						relFields = append(relFields, ref.PrimaryKey) | 						relFields = append(relFields, ref.PrimaryKey) | ||||||
|  | 						foreignKeyFields = append(foreignKeyFields, ref.ForeignKey) | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) | 					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) | ||||||
| @ -189,11 +205,14 @@ func (association *Association) Delete(values ...interface{}) error { | |||||||
| 		switch rel.Type { | 		switch rel.Type { | ||||||
| 		case schema.HasOne, schema.HasMany: | 		case schema.HasOne, schema.HasMany: | ||||||
| 			modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | 			modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | ||||||
|  | 			conds := rel.ToQueryConditions(reflectValue) | ||||||
| 			tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) | 			tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) | ||||||
| 		case schema.BelongsTo: | 		case schema.BelongsTo: | ||||||
| 			tx.Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) | 			modelValue := reflect.New(rel.Schema.ModelType).Interface() | ||||||
|  | 			tx.Model(modelValue).UpdateColumns(updateAttrs) | ||||||
| 		case schema.Many2Many: | 		case schema.Many2Many: | ||||||
| 			modelValue := reflect.New(rel.JoinTable.ModelType).Interface() | 			modelValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
|  | 			conds := rel.ToQueryConditions(reflectValue) | ||||||
| 			tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) | 			tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -216,13 +235,16 @@ func (association *Association) Delete(values ...interface{}) error { | |||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						rel.Field.Set(data, validFieldValues) | 						rel.Field.Set(data, validFieldValues.Interface()) | ||||||
| 					case reflect.Struct: | 					case reflect.Struct: | ||||||
| 						for idx, field := range relFields { | 						for idx, field := range relFields { | ||||||
| 							fieldValues[idx], _ = field.ValueOf(data) | 							fieldValues[idx], _ = field.ValueOf(fieldValue) | ||||||
| 						} | 						} | ||||||
| 						if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { | 						if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { | ||||||
| 							rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) | 							rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) | ||||||
|  | 							for _, field := range foreignKeyFields { | ||||||
|  | 								field.Set(data, reflect.Zero(field.FieldType).Interface()) | ||||||
|  | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| @ -275,7 +297,11 @@ func (association *Association) Count() (count int64) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (association *Association) saveAssociation(clear bool, values ...interface{}) { | func (association *Association) saveAssociation(clear bool, values ...interface{}) { | ||||||
| 	reflectValue := association.DB.Statement.ReflectValue | 	var ( | ||||||
|  | 		reflectValue = association.DB.Statement.ReflectValue | ||||||
|  | 		assignBacks  = [][2]reflect.Value{} | ||||||
|  | 		assignBack   = association.Relationship.Field.FieldType.Kind() == reflect.Struct | ||||||
|  | 	) | ||||||
| 
 | 
 | ||||||
| 	appendToRelations := func(source, rv reflect.Value, clear bool) { | 	appendToRelations := func(source, rv reflect.Value, clear bool) { | ||||||
| 		switch association.Relationship.Type { | 		switch association.Relationship.Type { | ||||||
| @ -283,10 +309,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 			switch rv.Kind() { | 			switch rv.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
| 				if rv.Len() > 0 { | 				if rv.Len() > 0 { | ||||||
| 					association.Error = association.Relationship.Field.Set(source, rv.Index(0)) | 					association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) | ||||||
|  | 					if assignBack { | ||||||
|  | 						assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)}) | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
| 				association.Error = association.Relationship.Field.Set(source, rv) | 				association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) | ||||||
|  | 				if assignBack { | ||||||
|  | 					assignBacks = append(assignBacks, [2]reflect.Value{source, rv}) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		case schema.HasMany, schema.Many2Many: | 		case schema.HasMany, schema.Many2Many: | ||||||
| 			elemType := association.Relationship.Field.IndirectFieldType.Elem() | 			elemType := association.Relationship.Field.IndirectFieldType.Elem() | ||||||
| @ -315,7 +347,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if association.Error == nil { | 			if association.Error == nil { | ||||||
| 				association.Error = association.Relationship.Field.Set(source, fieldValue) | 				association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface()) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @ -333,7 +365,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 		if len(values) != reflectValue.Len() { | 		if len(values) != reflectValue.Len() { | ||||||
| 			if clear && len(values) == 0 { | 			if clear && len(values) == 0 { | ||||||
| 				for i := 0; i < reflectValue.Len(); i++ { | 				for i := 0; i < reflectValue.Len(); i++ { | ||||||
| 					association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) | 					association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | ||||||
| 				} | 				} | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| @ -349,19 +381,24 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 		} | 		} | ||||||
| 	case reflect.Struct: | 	case reflect.Struct: | ||||||
| 		if clear && len(values) == 0 { | 		if clear && len(values) == 0 { | ||||||
| 			association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) | 			association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for idx, value := range values { | 		for idx, value := range values { | ||||||
| 			appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) | 			rv := reflect.Indirect(reflect.ValueOf(value)) | ||||||
|  | 			appendToRelations(reflectValue, rv, clear && idx == 0) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) | 		_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if hasZero { | 	if hasZero { | ||||||
| 		association.DB.Save(reflectValue.Interface()) | 		association.DB.Save(reflectValue.Addr().Interface()) | ||||||
| 	} else { | 	} else { | ||||||
| 		association.DB.Select(selectedColumns).Save(reflectValue.Interface()) | 		association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, assignBack := range assignBacks { | ||||||
|  | 		reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0])) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -73,8 +73,8 @@ func SaveBeforeAssociations(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| 					if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { | 					if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { | ||||||
| 						db.Session(&gorm.Session{}).Create(rv.Interface()) | 						db.Session(&gorm.Session{}).Create(rv.Interface()) | ||||||
| 						setupReferences(db.Statement.ReflectValue, rv) |  | ||||||
| 					} | 					} | ||||||
|  | 					setupReferences(db.Statement.ReflectValue, rv) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo | |||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if field := stmt.Schema.LookUpField(column); field != nil { | 		if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { | ||||||
| 			results[field.DBName] = true | 			results[field.DBName] = true | ||||||
| 		} else { | 		} else { | ||||||
| 			results[column] = true | 			results[column] = true | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/clause" | 	"github.com/jinzhu/gorm/clause" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BeforeUpdate(db *gorm.DB) { | func BeforeUpdate(db *gorm.DB) { | ||||||
| @ -91,8 +92,27 @@ func AfterUpdate(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| // ConvertToAssignments convert to update assignments
 | // ConvertToAssignments convert to update assignments
 | ||||||
| func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||||
| 	selectColumns, restricted := SelectAndOmitColumns(stmt, false, true) | 	var ( | ||||||
| 	reflectModelValue := reflect.ValueOf(stmt.Model) | 		selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) | ||||||
|  | 		reflectModelValue         = reflect.Indirect(reflect.ValueOf(stmt.Model)) | ||||||
|  | 		assignValue               func(field *schema.Field, value interface{}) | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	switch reflectModelValue.Kind() { | ||||||
|  | 	case reflect.Slice, reflect.Array: | ||||||
|  | 		assignValue = func(field *schema.Field, value interface{}) { | ||||||
|  | 			for i := 0; i < reflectModelValue.Len(); i++ { | ||||||
|  | 				field.Set(reflectModelValue.Index(i), value) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	case reflect.Struct: | ||||||
|  | 		assignValue = func(field *schema.Field, value interface{}) { | ||||||
|  | 			field.Set(reflectModelValue, value) | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		assignValue = func(field *schema.Field, value interface{}) { | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	switch value := stmt.Dest.(type) { | 	switch value := stmt.Dest.(type) { | ||||||
| 	case map[string]interface{}: | 	case map[string]interface{}: | ||||||
| @ -111,7 +131,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | |||||||
| 						value[k] = time.Now() | 						value[k] = time.Now() | ||||||
| 					} | 					} | ||||||
| 					set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) | 					set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) | ||||||
| 					field.Set(reflectModelValue, value[k]) | 					assignValue(field, value[k]) | ||||||
| 				} | 				} | ||||||
| 			} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | 			} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||||
| 				set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) | 				set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) | ||||||
| @ -122,7 +142,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | |||||||
| 			if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { | 			if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { | ||||||
| 				now := time.Now() | 				now := time.Now() | ||||||
| 				set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) | 				set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) | ||||||
| 				field.Set(reflectModelValue, now) | 				assignValue(field, now) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	default: | 	default: | ||||||
| @ -140,7 +160,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | |||||||
| 
 | 
 | ||||||
| 						if ok || !isZero { | 						if ok || !isZero { | ||||||
| 							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) | 							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) | ||||||
| 							field.Set(reflectModelValue, value) | 							assignValue(field, value) | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 				} else { | 				} else { | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								gorm.go
									
									
									
									
									
								
							| @ -105,11 +105,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | |||||||
| func (db *DB) Session(config *Session) *DB { | func (db *DB) Session(config *Session) *DB { | ||||||
| 	var ( | 	var ( | ||||||
| 		tx       = db.getInstance() | 		tx       = db.getInstance() | ||||||
|  | 		stmt     = tx.Statement.clone() | ||||||
| 		txConfig = *tx.Config | 		txConfig = *tx.Config | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	if config.Context != nil { | 	if config.Context != nil { | ||||||
| 		tx.Statement.Context = config.Context | 		stmt.Context = config.Context | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if config.Logger != nil { | 	if config.Logger != nil { | ||||||
| @ -120,9 +121,11 @@ func (db *DB) Session(config *Session) *DB { | |||||||
| 		txConfig.NowFunc = config.NowFunc | 		txConfig.NowFunc = config.NowFunc | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tx.Config = &txConfig | 	return &DB{ | ||||||
| 	tx.clone = true | 		Config:    &txConfig, | ||||||
| 	return tx | 		Statement: stmt, | ||||||
|  | 		clone:     true, | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // WithContext change current instance db's context to ctx
 | // WithContext change current instance db's context to ctx
 | ||||||
|  | |||||||
| @ -372,19 +372,24 @@ func (field *Field) setupValuerAndSetter() { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { | 	recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { | ||||||
| 		reflectV := reflect.ValueOf(v) | 		if v == nil { | ||||||
| 		if reflectV.Type().ConvertibleTo(field.FieldType) { | 			field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) | ||||||
| 			field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) |  | ||||||
| 		} else if valuer, ok := v.(driver.Valuer); ok { |  | ||||||
| 			if v, err = valuer.Value(); err == nil { |  | ||||||
| 				return setter(value, v) |  | ||||||
| 			} |  | ||||||
| 		} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { |  | ||||||
| 			field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) |  | ||||||
| 		} else if reflectV.Kind() == reflect.Ptr { |  | ||||||
| 			return field.Set(value, reflectV.Elem().Interface()) |  | ||||||
| 		} else { | 		} else { | ||||||
| 			return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) | 			reflectV := reflect.ValueOf(v) | ||||||
|  | 
 | ||||||
|  | 			if reflectV.Type().ConvertibleTo(field.FieldType) { | ||||||
|  | 				field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) | ||||||
|  | 			} else if valuer, ok := v.(driver.Valuer); ok { | ||||||
|  | 				if v, err = valuer.Value(); err == nil { | ||||||
|  | 					return setter(value, v) | ||||||
|  | 				} | ||||||
|  | 			} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { | ||||||
|  | 				field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) | ||||||
|  | 			} else if reflectV.Kind() == reflect.Ptr { | ||||||
|  | 				return field.Set(value, reflectV.Elem().Interface()) | ||||||
|  | 			} else { | ||||||
|  | 				return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -387,6 +387,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] | |||||||
| 
 | 
 | ||||||
| 	_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) | 	_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) | ||||||
| 	column, values := ToQueryValues(relForeignKeys, foreignValues) | 	column, values := ToQueryValues(relForeignKeys, foreignValues) | ||||||
|  | 
 | ||||||
| 	conds = append(conds, clause.IN{Column: column, Values: values}) | 	conds = append(conds, clause.IN{Column: column, Values: values}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										33
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								statement.go
									
									
									
									
									
								
							| @ -278,6 +278,39 @@ func (stmt *Statement) Parse(value interface{}) (err error) { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (stmt *Statement) clone() *Statement { | ||||||
|  | 	newStmt := &Statement{ | ||||||
|  | 		DB:                   stmt.DB, | ||||||
|  | 		Table:                stmt.Table, | ||||||
|  | 		Model:                stmt.Model, | ||||||
|  | 		Dest:                 stmt.Dest, | ||||||
|  | 		ReflectValue:         stmt.ReflectValue, | ||||||
|  | 		Clauses:              map[string]clause.Clause{}, | ||||||
|  | 		Selects:              stmt.Selects, | ||||||
|  | 		Omits:                stmt.Omits, | ||||||
|  | 		Joins:                map[string][]interface{}{}, | ||||||
|  | 		Preloads:             map[string][]interface{}{}, | ||||||
|  | 		ConnPool:             stmt.ConnPool, | ||||||
|  | 		Schema:               stmt.Schema, | ||||||
|  | 		Context:              stmt.Context, | ||||||
|  | 		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for k, c := range stmt.Clauses { | ||||||
|  | 		newStmt.Clauses[k] = c | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for k, p := range stmt.Preloads { | ||||||
|  | 		newStmt.Preloads[k] = p | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for k, j := range stmt.Joins { | ||||||
|  | 		newStmt.Joins[k] = j | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return newStmt | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (stmt *Statement) reinit() { | func (stmt *Statement) reinit() { | ||||||
| 	// stmt.Table = ""
 | 	// stmt.Table = ""
 | ||||||
| 	// stmt.Model = nil
 | 	// stmt.Model = nil
 | ||||||
|  | |||||||
| @ -15,6 +15,7 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	CheckUser(t, user, user) | 	CheckUser(t, user, user) | ||||||
| 
 | 
 | ||||||
|  | 	// Find
 | ||||||
| 	var user2 User | 	var user2 User | ||||||
| 	DB.Find(&user2, "id = ?", user.ID) | 	DB.Find(&user2, "id = ?", user.ID) | ||||||
| 	DB.Model(&user2).Association("Company").Find(&user2.Company) | 	DB.Model(&user2).Association("Company").Find(&user2.Company) | ||||||
| @ -22,6 +23,7 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 	DB.Model(&user2).Association("Manager").Find(user2.Manager) | 	DB.Model(&user2).Association("Manager").Find(user2.Manager) | ||||||
| 	CheckUser(t, user2, user) | 	CheckUser(t, user2, user) | ||||||
| 
 | 
 | ||||||
|  | 	// Count
 | ||||||
| 	if count := DB.Model(&user).Association("Company").Count(); count != 1 { | 	if count := DB.Model(&user).Association("Company").Count(); count != 1 { | ||||||
| 		t.Errorf("invalid company count, got %v", count) | 		t.Errorf("invalid company count, got %v", count) | ||||||
| 	} | 	} | ||||||
| @ -29,4 +31,123 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 	if count := DB.Model(&user).Association("Manager").Count(); count != 1 { | 	if count := DB.Model(&user).Association("Manager").Count(); count != 1 { | ||||||
| 		t.Errorf("invalid manager count, got %v", count) | 		t.Errorf("invalid manager count, got %v", count) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	// Append
 | ||||||
|  | 	var company = Company{Name: "company-belongs-to-append"} | ||||||
|  | 	var manager = GetUser("manager-belongs-to-append", Config{}) | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when append Company, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if company.ID == 0 { | ||||||
|  | 		t.Fatalf("Company's ID should be created") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when append Manager, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if manager.ID == 0 { | ||||||
|  | 		t.Fatalf("Manager's ID should be created") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user.Company = company | ||||||
|  | 	user.Manager = manager | ||||||
|  | 	user.CompanyID = &company.ID | ||||||
|  | 	user.ManagerID = &manager.ID | ||||||
|  | 	CheckUser(t, user2, user) | ||||||
|  | 
 | ||||||
|  | 	// Replace
 | ||||||
|  | 	var company2 = Company{Name: "company-belongs-to-replace"} | ||||||
|  | 	var manager2 = GetUser("manager-belongs-to-replace", Config{}) | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when replace Company, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if company2.ID == 0 { | ||||||
|  | 		t.Fatalf("Company's ID should be created") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when replace Manager, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if manager2.ID == 0 { | ||||||
|  | 		t.Fatalf("Manager's ID should be created") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user.Company = company2 | ||||||
|  | 	user.Manager = manager2 | ||||||
|  | 	user.CompanyID = &company2.ID | ||||||
|  | 	user.ManagerID = &manager2.ID | ||||||
|  | 	CheckUser(t, user2, user) | ||||||
|  | 
 | ||||||
|  | 	// Delete
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when delete Company, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Company").Count(); count != 1 { | ||||||
|  | 		t.Errorf("Invalid company count after delete non-existing association, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when delete Company, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Company").Count(); count != 0 { | ||||||
|  | 		t.Errorf("Invalid company count after delete, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when delete Manager, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { | ||||||
|  | 		t.Errorf("Invalid manager count after delete non-existing association, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when delete Manager, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { | ||||||
|  | 		t.Errorf("Invalid manager count after delete, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Prepare Data
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when append Company, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { | ||||||
|  | 		t.Fatalf("Error happened when append Manager, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Company").Count(); count != 1 { | ||||||
|  | 		t.Errorf("Invalid company count after append, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { | ||||||
|  | 		t.Errorf("Invalid manager count after append, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Clear
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Company").Clear(); err != nil { | ||||||
|  | 		t.Errorf("Error happened when clear Company, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { | ||||||
|  | 		t.Errorf("Error happened when clear Manager, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Company").Count(); count != 0 { | ||||||
|  | 		t.Errorf("Invalid company count after clear, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { | ||||||
|  | 		t.Errorf("Invalid manager count after clear, got %v", count) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -33,7 +33,7 @@ func TestCount(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	var count3 int64 | 	var count3 int64 | ||||||
| 	if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { | 	if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { | ||||||
| 		t.Errorf("No error should happen when count with group, but got %v", err) | 		t.Errorf("Error happened when count with group, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if count3 != 2 { | 	if count3 != 2 { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu