Cleanup
This commit is contained in:
		
							parent
							
								
									7180bd0f27
								
							
						
					
					
						commit
						f0d514e330
					
				
							
								
								
									
										377
									
								
								association.go
									
									
									
									
									
								
							
							
						
						
									
										377
									
								
								association.go
									
									
									
									
									
								
							| @ -1,377 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // Association Mode contains some helper methods to handle relationship things easily.
 | ||||
| type Association struct { | ||||
| 	Error  error | ||||
| 	scope  *Scope | ||||
| 	column string | ||||
| 	field  *Field | ||||
| } | ||||
| 
 | ||||
| // Find find out all related associations
 | ||||
| func (association *Association) Find(value interface{}) *Association { | ||||
| 	association.scope.related(value, association.column) | ||||
| 	return association.setErr(association.scope.db.Error) | ||||
| } | ||||
| 
 | ||||
| // Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
 | ||||
| func (association *Association) Append(values ...interface{}) *Association { | ||||
| 	if association.Error != nil { | ||||
| 		return association | ||||
| 	} | ||||
| 
 | ||||
| 	if relationship := association.field.Relationship; relationship.Kind == "has_one" { | ||||
| 		return association.Replace(values...) | ||||
| 	} | ||||
| 	return association.saveAssociations(values...) | ||||
| } | ||||
| 
 | ||||
| // Replace replace current associations with new one
 | ||||
| func (association *Association) Replace(values ...interface{}) *Association { | ||||
| 	if association.Error != nil { | ||||
| 		return association | ||||
| 	} | ||||
| 
 | ||||
| 	var ( | ||||
| 		relationship = association.field.Relationship | ||||
| 		scope        = association.scope | ||||
| 		field        = association.field.Field | ||||
| 		newDB        = scope.NewDB() | ||||
| 	) | ||||
| 
 | ||||
| 	// Append new values
 | ||||
| 	association.field.Set(reflect.Zero(association.field.Field.Type())) | ||||
| 	association.saveAssociations(values...) | ||||
| 
 | ||||
| 	// Belongs To
 | ||||
| 	if relationship.Kind == "belongs_to" { | ||||
| 		// Set foreign key to be null when clearing value (length equals 0)
 | ||||
| 		if len(values) == 0 { | ||||
| 			// Set foreign key to be nil
 | ||||
| 			var foreignKeyMap = map[string]interface{}{} | ||||
| 			for _, foreignKey := range relationship.ForeignDBNames { | ||||
| 				foreignKeyMap[foreignKey] = nil | ||||
| 			} | ||||
| 			association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) | ||||
| 		} | ||||
| 	} else { | ||||
| 		// Polymorphic Relations
 | ||||
| 		if relationship.PolymorphicDBName != "" { | ||||
| 			newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) | ||||
| 		} | ||||
| 
 | ||||
| 		// Delete Relations except new created
 | ||||
| 		if len(values) > 0 { | ||||
| 			var associationForeignFieldNames, associationForeignDBNames []string | ||||
| 			if relationship.Kind == "many_to_many" { | ||||
| 				// if many to many relations, get association fields name from association foreign keys
 | ||||
| 				associationScope := scope.New(reflect.New(field.Type()).Interface()) | ||||
| 				for idx, dbName := range relationship.AssociationForeignFieldNames { | ||||
| 					if field, ok := associationScope.FieldByName(dbName); ok { | ||||
| 						associationForeignFieldNames = append(associationForeignFieldNames, field.Name) | ||||
| 						associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				// If has one/many relations, use primary keys
 | ||||
| 				for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { | ||||
| 					associationForeignFieldNames = append(associationForeignFieldNames, field.Name) | ||||
| 					associationForeignDBNames = append(associationForeignDBNames, field.DBName) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) | ||||
| 
 | ||||
| 			if len(newPrimaryKeys) > 0 { | ||||
| 				sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) | ||||
| 				newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if relationship.Kind == "many_to_many" { | ||||
| 			// if many to many relations, delete related relations from join table
 | ||||
| 			var sourceForeignFieldNames []string | ||||
| 
 | ||||
| 			for _, dbName := range relationship.ForeignFieldNames { | ||||
| 				if field, ok := scope.FieldByName(dbName); ok { | ||||
| 					sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { | ||||
| 				newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) | ||||
| 
 | ||||
| 				association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) | ||||
| 			} | ||||
| 		} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { | ||||
| 			// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
 | ||||
| 			var foreignKeyMap = map[string]interface{}{} | ||||
| 			for idx, foreignKey := range relationship.ForeignDBNames { | ||||
| 				foreignKeyMap[foreignKey] = nil | ||||
| 				if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { | ||||
| 					newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			fieldValue := reflect.New(association.field.Field.Type()).Interface() | ||||
| 			association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) | ||||
| 		} | ||||
| 	} | ||||
| 	return association | ||||
| } | ||||
| 
 | ||||
| // Delete remove relationship between source & passed arguments, but won't delete those arguments
 | ||||
| func (association *Association) Delete(values ...interface{}) *Association { | ||||
| 	if association.Error != nil { | ||||
| 		return association | ||||
| 	} | ||||
| 
 | ||||
| 	var ( | ||||
| 		relationship = association.field.Relationship | ||||
| 		scope        = association.scope | ||||
| 		field        = association.field.Field | ||||
| 		newDB        = scope.NewDB() | ||||
| 	) | ||||
| 
 | ||||
| 	if len(values) == 0 { | ||||
| 		return association | ||||
| 	} | ||||
| 
 | ||||
| 	var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string | ||||
| 	for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { | ||||
| 		deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) | ||||
| 		deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) | ||||
| 	} | ||||
| 
 | ||||
| 	deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) | ||||
| 
 | ||||
| 	if relationship.Kind == "many_to_many" { | ||||
| 		// source value's foreign keys
 | ||||
| 		for idx, foreignKey := range relationship.ForeignDBNames { | ||||
| 			if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { | ||||
| 				newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// get association's foreign fields name
 | ||||
| 		var associationScope = scope.New(reflect.New(field.Type()).Interface()) | ||||
| 		var associationForeignFieldNames []string | ||||
| 		for _, associationDBName := range relationship.AssociationForeignFieldNames { | ||||
| 			if field, ok := associationScope.FieldByName(associationDBName); ok { | ||||
| 				associationForeignFieldNames = append(associationForeignFieldNames, field.Name) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// association value's foreign keys
 | ||||
| 		deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) | ||||
| 		sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) | ||||
| 		newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) | ||||
| 
 | ||||
| 		association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) | ||||
| 	} else { | ||||
| 		var foreignKeyMap = map[string]interface{}{} | ||||
| 		for _, foreignKey := range relationship.ForeignDBNames { | ||||
| 			foreignKeyMap[foreignKey] = nil | ||||
| 		} | ||||
| 
 | ||||
| 		if relationship.Kind == "belongs_to" { | ||||
| 			// find with deleting relation's foreign keys
 | ||||
| 			primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) | ||||
| 			newDB = newDB.Where( | ||||
| 				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), | ||||
| 				toQueryValues(primaryKeys)..., | ||||
| 			) | ||||
| 
 | ||||
| 			// set foreign key to be null if there are some records affected
 | ||||
| 			modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() | ||||
| 			if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { | ||||
| 				if results.RowsAffected > 0 { | ||||
| 					scope.updatedAttrsWithValues(foreignKeyMap) | ||||
| 				} | ||||
| 			} else { | ||||
| 				association.setErr(results.Error) | ||||
| 			} | ||||
| 		} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { | ||||
| 			// find all relations
 | ||||
| 			primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) | ||||
| 			newDB = newDB.Where( | ||||
| 				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), | ||||
| 				toQueryValues(primaryKeys)..., | ||||
| 			) | ||||
| 
 | ||||
| 			// only include those deleting relations
 | ||||
| 			newDB = newDB.Where( | ||||
| 				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), | ||||
| 				toQueryValues(deletingPrimaryKeys)..., | ||||
| 			) | ||||
| 
 | ||||
| 			// set matched relation's foreign key to be null
 | ||||
| 			fieldValue := reflect.New(association.field.Field.Type()).Interface() | ||||
| 			association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Remove deleted records from source's field
 | ||||
| 	if association.Error == nil { | ||||
| 		if field.Kind() == reflect.Slice { | ||||
| 			leftValues := reflect.Zero(field.Type()) | ||||
| 
 | ||||
| 			for i := 0; i < field.Len(); i++ { | ||||
| 				reflectValue := field.Index(i) | ||||
| 				primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] | ||||
| 				var isDeleted = false | ||||
| 				for _, pk := range deletingPrimaryKeys { | ||||
| 					if equalAsString(primaryKey, pk) { | ||||
| 						isDeleted = true | ||||
| 						break | ||||
| 					} | ||||
| 				} | ||||
| 				if !isDeleted { | ||||
| 					leftValues = reflect.Append(leftValues, reflectValue) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			association.field.Set(leftValues) | ||||
| 		} else if field.Kind() == reflect.Struct { | ||||
| 			primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] | ||||
| 			for _, pk := range deletingPrimaryKeys { | ||||
| 				if equalAsString(primaryKey, pk) { | ||||
| 					association.field.Set(reflect.Zero(field.Type())) | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return association | ||||
| } | ||||
| 
 | ||||
| // Clear remove relationship between source & current associations, won't delete those associations
 | ||||
| func (association *Association) Clear() *Association { | ||||
| 	return association.Replace() | ||||
| } | ||||
| 
 | ||||
| // Count return the count of current associations
 | ||||
| func (association *Association) Count() int { | ||||
| 	var ( | ||||
| 		count        = 0 | ||||
| 		relationship = association.field.Relationship | ||||
| 		scope        = association.scope | ||||
| 		fieldValue   = association.field.Field.Interface() | ||||
| 		query        = scope.DB() | ||||
| 	) | ||||
| 
 | ||||
| 	switch relationship.Kind { | ||||
| 	case "many_to_many": | ||||
| 		query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) | ||||
| 	case "has_many", "has_one": | ||||
| 		primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) | ||||
| 		query = query.Where( | ||||
| 			fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), | ||||
| 			toQueryValues(primaryKeys)..., | ||||
| 		) | ||||
| 	case "belongs_to": | ||||
| 		primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) | ||||
| 		query = query.Where( | ||||
| 			fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), | ||||
| 			toQueryValues(primaryKeys)..., | ||||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	if relationship.PolymorphicType != "" { | ||||
| 		query = query.Where( | ||||
| 			fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), | ||||
| 			relationship.PolymorphicValue, | ||||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := query.Model(fieldValue).Count(&count).Error; err != nil { | ||||
| 		association.Error = err | ||||
| 	} | ||||
| 	return count | ||||
| } | ||||
| 
 | ||||
| // saveAssociations save passed values as associations
 | ||||
| func (association *Association) saveAssociations(values ...interface{}) *Association { | ||||
| 	var ( | ||||
| 		scope        = association.scope | ||||
| 		field        = association.field | ||||
| 		relationship = field.Relationship | ||||
| 	) | ||||
| 
 | ||||
| 	saveAssociation := func(reflectValue reflect.Value) { | ||||
| 		// value has to been pointer
 | ||||
| 		if reflectValue.Kind() != reflect.Ptr { | ||||
| 			reflectPtr := reflect.New(reflectValue.Type()) | ||||
| 			reflectPtr.Elem().Set(reflectValue) | ||||
| 			reflectValue = reflectPtr | ||||
| 		} | ||||
| 
 | ||||
| 		// value has to been saved for many2many
 | ||||
| 		if relationship.Kind == "many_to_many" { | ||||
| 			if scope.New(reflectValue.Interface()).PrimaryKeyZero() { | ||||
| 				association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// Assign Fields
 | ||||
| 		var fieldType = field.Field.Type() | ||||
| 		var setFieldBackToValue, setSliceFieldBackToValue bool | ||||
| 		if reflectValue.Type().AssignableTo(fieldType) { | ||||
| 			field.Set(reflectValue) | ||||
| 		} else if reflectValue.Type().Elem().AssignableTo(fieldType) { | ||||
| 			// if field's type is struct, then need to set value back to argument after save
 | ||||
| 			setFieldBackToValue = true | ||||
| 			field.Set(reflectValue.Elem()) | ||||
| 		} else if fieldType.Kind() == reflect.Slice { | ||||
| 			if reflectValue.Type().AssignableTo(fieldType.Elem()) { | ||||
| 				field.Set(reflect.Append(field.Field, reflectValue)) | ||||
| 			} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { | ||||
| 				// if field's type is slice of struct, then need to set value back to argument after save
 | ||||
| 				setSliceFieldBackToValue = true | ||||
| 				field.Set(reflect.Append(field.Field, reflectValue.Elem())) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if relationship.Kind == "many_to_many" { | ||||
| 			association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) | ||||
| 		} else { | ||||
| 			association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) | ||||
| 
 | ||||
| 			if setFieldBackToValue { | ||||
| 				reflectValue.Elem().Set(field.Field) | ||||
| 			} else if setSliceFieldBackToValue { | ||||
| 				reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, value := range values { | ||||
| 		reflectValue := reflect.ValueOf(value) | ||||
| 		indirectReflectValue := reflect.Indirect(reflectValue) | ||||
| 		if indirectReflectValue.Kind() == reflect.Struct { | ||||
| 			saveAssociation(reflectValue) | ||||
| 		} else if indirectReflectValue.Kind() == reflect.Slice { | ||||
| 			for i := 0; i < indirectReflectValue.Len(); i++ { | ||||
| 				saveAssociation(indirectReflectValue.Index(i)) | ||||
| 			} | ||||
| 		} else { | ||||
| 			association.setErr(errors.New("invalid value type")) | ||||
| 		} | ||||
| 	} | ||||
| 	return association | ||||
| } | ||||
| 
 | ||||
| // setErr set error when the error is not nil. And return Association.
 | ||||
| func (association *Association) setErr(err error) *Association { | ||||
| 	if err != nil { | ||||
| 		association.Error = err | ||||
| 	} | ||||
| 	return association | ||||
| } | ||||
							
								
								
									
										1050
									
								
								association_test.go
									
									
									
									
									
								
							
							
						
						
									
										1050
									
								
								association_test.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										250
									
								
								callback.go
									
									
									
									
									
								
							
							
						
						
									
										250
									
								
								callback.go
									
									
									
									
									
								
							| @ -1,250 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import "fmt" | ||||
| 
 | ||||
| // DefaultCallback default callbacks defined by gorm
 | ||||
| var DefaultCallback = &Callback{logger: nopLogger{}} | ||||
| 
 | ||||
| // Callback is a struct that contains all CRUD callbacks
 | ||||
| //   Field `creates` contains callbacks will be call when creating object
 | ||||
| //   Field `updates` contains callbacks will be call when updating object
 | ||||
| //   Field `deletes` contains callbacks will be call when deleting object
 | ||||
| //   Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
 | ||||
| //   Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
 | ||||
| //   Field `processors` contains all callback processors, will be used to generate above callbacks in order
 | ||||
| type Callback struct { | ||||
| 	logger     logger | ||||
| 	creates    []*func(scope *Scope) | ||||
| 	updates    []*func(scope *Scope) | ||||
| 	deletes    []*func(scope *Scope) | ||||
| 	queries    []*func(scope *Scope) | ||||
| 	rowQueries []*func(scope *Scope) | ||||
| 	processors []*CallbackProcessor | ||||
| } | ||||
| 
 | ||||
| // CallbackProcessor contains callback informations
 | ||||
| type CallbackProcessor struct { | ||||
| 	logger    logger | ||||
| 	name      string              // current callback's name
 | ||||
| 	before    string              // register current callback before a callback
 | ||||
| 	after     string              // register current callback after a callback
 | ||||
| 	replace   bool                // replace callbacks with same name
 | ||||
| 	remove    bool                // delete callbacks with same name
 | ||||
| 	kind      string              // callback type: create, update, delete, query, row_query
 | ||||
| 	processor *func(scope *Scope) // callback handler
 | ||||
| 	parent    *Callback | ||||
| } | ||||
| 
 | ||||
| func (c *Callback) clone(logger logger) *Callback { | ||||
| 	return &Callback{ | ||||
| 		logger:     logger, | ||||
| 		creates:    c.creates, | ||||
| 		updates:    c.updates, | ||||
| 		deletes:    c.deletes, | ||||
| 		queries:    c.queries, | ||||
| 		rowQueries: c.rowQueries, | ||||
| 		processors: c.processors, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Create could be used to register callbacks for creating object
 | ||||
| //     db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
 | ||||
| //       // business logic
 | ||||
| //       ...
 | ||||
| //
 | ||||
| //       // set error if some thing wrong happened, will rollback the creating
 | ||||
| //       scope.Err(errors.New("error"))
 | ||||
| //     })
 | ||||
| func (c *Callback) Create() *CallbackProcessor { | ||||
| 	return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} | ||||
| } | ||||
| 
 | ||||
| // Update could be used to register callbacks for updating object, refer `Create` for usage
 | ||||
| func (c *Callback) Update() *CallbackProcessor { | ||||
| 	return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} | ||||
| } | ||||
| 
 | ||||
| // Delete could be used to register callbacks for deleting object, refer `Create` for usage
 | ||||
| func (c *Callback) Delete() *CallbackProcessor { | ||||
| 	return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} | ||||
| } | ||||
| 
 | ||||
| // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
 | ||||
| // Refer `Create` for usage
 | ||||
| func (c *Callback) Query() *CallbackProcessor { | ||||
| 	return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} | ||||
| } | ||||
| 
 | ||||
| // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
 | ||||
| func (c *Callback) RowQuery() *CallbackProcessor { | ||||
| 	return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} | ||||
| } | ||||
| 
 | ||||
| // After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
 | ||||
| func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { | ||||
| 	cp.after = callbackName | ||||
| 	return cp | ||||
| } | ||||
| 
 | ||||
| // Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
 | ||||
| func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { | ||||
| 	cp.before = callbackName | ||||
| 	return cp | ||||
| } | ||||
| 
 | ||||
| // Register a new callback, refer `Callbacks.Create`
 | ||||
| func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { | ||||
| 	if cp.kind == "row_query" { | ||||
| 		if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { | ||||
| 			cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) | ||||
| 			cp.before = "gorm:row_query" | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) | ||||
| 	cp.name = callbackName | ||||
| 	cp.processor = &callback | ||||
| 	cp.parent.processors = append(cp.parent.processors, cp) | ||||
| 	cp.parent.reorder() | ||||
| } | ||||
| 
 | ||||
| // Remove a registered callback
 | ||||
| //     db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
 | ||||
| func (cp *CallbackProcessor) Remove(callbackName string) { | ||||
| 	cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) | ||||
| 	cp.name = callbackName | ||||
| 	cp.remove = true | ||||
| 	cp.parent.processors = append(cp.parent.processors, cp) | ||||
| 	cp.parent.reorder() | ||||
| } | ||||
| 
 | ||||
| // Replace a registered callback with new callback
 | ||||
| //     db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
 | ||||
| //		   scope.SetColumn("CreatedAt", now)
 | ||||
| //		   scope.SetColumn("UpdatedAt", now)
 | ||||
| //     })
 | ||||
| func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { | ||||
| 	cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) | ||||
| 	cp.name = callbackName | ||||
| 	cp.processor = &callback | ||||
| 	cp.replace = true | ||||
| 	cp.parent.processors = append(cp.parent.processors, cp) | ||||
| 	cp.parent.reorder() | ||||
| } | ||||
| 
 | ||||
| // Get registered callback
 | ||||
| //    db.Callback().Create().Get("gorm:create")
 | ||||
| func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { | ||||
| 	for _, p := range cp.parent.processors { | ||||
| 		if p.name == callbackName && p.kind == cp.kind { | ||||
| 			if p.remove { | ||||
| 				callback = nil | ||||
| 			} else { | ||||
| 				callback = *p.processor | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // getRIndex get right index from string slice
 | ||||
| func getRIndex(strs []string, str string) int { | ||||
| 	for i := len(strs) - 1; i >= 0; i-- { | ||||
| 		if strs[i] == str { | ||||
| 			return i | ||||
| 		} | ||||
| 	} | ||||
| 	return -1 | ||||
| } | ||||
| 
 | ||||
| // sortProcessors sort callback processors based on its before, after, remove, replace
 | ||||
| func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { | ||||
| 	var ( | ||||
| 		allNames, sortedNames []string | ||||
| 		sortCallbackProcessor func(c *CallbackProcessor) | ||||
| 	) | ||||
| 
 | ||||
| 	for _, cp := range cps { | ||||
| 		// show warning message the callback name already exists
 | ||||
| 		if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { | ||||
| 			cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) | ||||
| 		} | ||||
| 		allNames = append(allNames, cp.name) | ||||
| 	} | ||||
| 
 | ||||
| 	sortCallbackProcessor = func(c *CallbackProcessor) { | ||||
| 		if getRIndex(sortedNames, c.name) == -1 { // if not sorted
 | ||||
| 			if c.before != "" { // if defined before callback
 | ||||
| 				if index := getRIndex(sortedNames, c.before); index != -1 { | ||||
| 					// if before callback already sorted, append current callback just after it
 | ||||
| 					sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) | ||||
| 				} else if index := getRIndex(allNames, c.before); index != -1 { | ||||
| 					// if before callback exists but haven't sorted, append current callback to last
 | ||||
| 					sortedNames = append(sortedNames, c.name) | ||||
| 					sortCallbackProcessor(cps[index]) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if c.after != "" { // if defined after callback
 | ||||
| 				if index := getRIndex(sortedNames, c.after); index != -1 { | ||||
| 					// if after callback already sorted, append current callback just before it
 | ||||
| 					sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) | ||||
| 				} else if index := getRIndex(allNames, c.after); index != -1 { | ||||
| 					// if after callback exists but haven't sorted
 | ||||
| 					cp := cps[index] | ||||
| 					// set after callback's before callback to current callback
 | ||||
| 					if cp.before == "" { | ||||
| 						cp.before = c.name | ||||
| 					} | ||||
| 					sortCallbackProcessor(cp) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// if current callback haven't been sorted, append it to last
 | ||||
| 			if getRIndex(sortedNames, c.name) == -1 { | ||||
| 				sortedNames = append(sortedNames, c.name) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, cp := range cps { | ||||
| 		sortCallbackProcessor(cp) | ||||
| 	} | ||||
| 
 | ||||
| 	var sortedFuncs []*func(scope *Scope) | ||||
| 	for _, name := range sortedNames { | ||||
| 		if index := getRIndex(allNames, name); !cps[index].remove { | ||||
| 			sortedFuncs = append(sortedFuncs, cps[index].processor) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return sortedFuncs | ||||
| } | ||||
| 
 | ||||
| // reorder all registered processors, and reset CRUD callbacks
 | ||||
| func (c *Callback) reorder() { | ||||
| 	var creates, updates, deletes, queries, rowQueries []*CallbackProcessor | ||||
| 
 | ||||
| 	for _, processor := range c.processors { | ||||
| 		if processor.name != "" { | ||||
| 			switch processor.kind { | ||||
| 			case "create": | ||||
| 				creates = append(creates, processor) | ||||
| 			case "update": | ||||
| 				updates = append(updates, processor) | ||||
| 			case "delete": | ||||
| 				deletes = append(deletes, processor) | ||||
| 			case "query": | ||||
| 				queries = append(queries, processor) | ||||
| 			case "row_query": | ||||
| 				rowQueries = append(rowQueries, processor) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	c.creates = sortProcessors(creates) | ||||
| 	c.updates = sortProcessors(updates) | ||||
| 	c.deletes = sortProcessors(deletes) | ||||
| 	c.queries = sortProcessors(queries) | ||||
| 	c.rowQueries = sortProcessors(rowQueries) | ||||
| } | ||||
| @ -1,197 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // Define callbacks for creating
 | ||||
| func init() { | ||||
| 	DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:create", createCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) | ||||
| 	DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) | ||||
| } | ||||
| 
 | ||||
| // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
 | ||||
| func beforeCreateCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		scope.CallMethod("BeforeSave") | ||||
| 	} | ||||
| 	if !scope.HasError() { | ||||
| 		scope.CallMethod("BeforeCreate") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
 | ||||
| func updateTimeStampForCreateCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		now := scope.db.nowFunc() | ||||
| 
 | ||||
| 		if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { | ||||
| 			if createdAtField.IsBlank { | ||||
| 				createdAtField.Set(now) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { | ||||
| 			if updatedAtField.IsBlank { | ||||
| 				updatedAtField.Set(now) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // createCallback the callback used to insert data into database
 | ||||
| func createCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		defer scope.trace(NowFunc()) | ||||
| 
 | ||||
| 		var ( | ||||
| 			columns, placeholders        []string | ||||
| 			blankColumnsWithDefaultValue []string | ||||
| 		) | ||||
| 
 | ||||
| 		for _, field := range scope.Fields() { | ||||
| 			if scope.changeableField(field) { | ||||
| 				if field.IsNormal && !field.IsIgnored { | ||||
| 					if field.IsBlank && field.HasDefaultValue { | ||||
| 						blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) | ||||
| 						scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) | ||||
| 					} else if !field.IsPrimaryKey || !field.IsBlank { | ||||
| 						columns = append(columns, scope.Quote(field.DBName)) | ||||
| 						placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) | ||||
| 					} | ||||
| 				} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { | ||||
| 					for _, foreignKey := range field.Relationship.ForeignDBNames { | ||||
| 						if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { | ||||
| 							columns = append(columns, scope.Quote(foreignField.DBName)) | ||||
| 							placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		var ( | ||||
| 			returningColumn = "*" | ||||
| 			quotedTableName = scope.QuotedTableName() | ||||
| 			primaryField    = scope.PrimaryField() | ||||
| 			extraOption     string | ||||
| 			insertModifier  string | ||||
| 		) | ||||
| 
 | ||||
| 		if str, ok := scope.Get("gorm:insert_option"); ok { | ||||
| 			extraOption = fmt.Sprint(str) | ||||
| 		} | ||||
| 		if str, ok := scope.Get("gorm:insert_modifier"); ok { | ||||
| 			insertModifier = strings.ToUpper(fmt.Sprint(str)) | ||||
| 			if insertModifier == "INTO" { | ||||
| 				insertModifier = "" | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if primaryField != nil { | ||||
| 			returningColumn = scope.Quote(primaryField.DBName) | ||||
| 		} | ||||
| 
 | ||||
| 		lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) | ||||
| 		var lastInsertIDReturningSuffix string | ||||
| 		if lastInsertIDOutputInterstitial == "" { | ||||
| 			lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) | ||||
| 		} | ||||
| 
 | ||||
| 		if len(columns) == 0 { | ||||
| 			scope.Raw(fmt.Sprintf( | ||||
| 				"INSERT%v INTO %v %v%v%v", | ||||
| 				addExtraSpaceIfExist(insertModifier), | ||||
| 				quotedTableName, | ||||
| 				scope.Dialect().DefaultValueStr(), | ||||
| 				addExtraSpaceIfExist(extraOption), | ||||
| 				addExtraSpaceIfExist(lastInsertIDReturningSuffix), | ||||
| 			)) | ||||
| 		} else { | ||||
| 			scope.Raw(fmt.Sprintf( | ||||
| 				"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", | ||||
| 				addExtraSpaceIfExist(insertModifier), | ||||
| 				scope.QuotedTableName(), | ||||
| 				strings.Join(columns, ","), | ||||
| 				addExtraSpaceIfExist(lastInsertIDOutputInterstitial), | ||||
| 				strings.Join(placeholders, ","), | ||||
| 				addExtraSpaceIfExist(extraOption), | ||||
| 				addExtraSpaceIfExist(lastInsertIDReturningSuffix), | ||||
| 			)) | ||||
| 		} | ||||
| 
 | ||||
| 		// execute create sql: no primaryField
 | ||||
| 		if primaryField == nil { | ||||
| 			if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { | ||||
| 				// set rows affected count
 | ||||
| 				scope.db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 				// set primary value to primary field
 | ||||
| 				if primaryField != nil && primaryField.IsBlank { | ||||
| 					if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { | ||||
| 						scope.Err(primaryField.Set(primaryValue)) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// execute create sql: lastInsertID implemention for majority of dialects
 | ||||
| 		if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { | ||||
| 			if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { | ||||
| 				// set rows affected count
 | ||||
| 				scope.db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 				// set primary value to primary field
 | ||||
| 				if primaryField != nil && primaryField.IsBlank { | ||||
| 					if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { | ||||
| 						scope.Err(primaryField.Set(primaryValue)) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
 | ||||
| 		if primaryField.Field.CanAddr() { | ||||
| 			if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { | ||||
| 				primaryField.IsBlank = false | ||||
| 				scope.db.RowsAffected = 1 | ||||
| 			} | ||||
| 		} else { | ||||
| 			scope.Err(ErrUnaddressable) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
 | ||||
| func forceReloadAfterCreateCallback(scope *Scope) { | ||||
| 	if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { | ||||
| 		db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) | ||||
| 		for _, field := range scope.Fields() { | ||||
| 			if field.IsPrimaryKey && !field.IsBlank { | ||||
| 				db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) | ||||
| 			} | ||||
| 		} | ||||
| 		db.Scan(scope.Value) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
 | ||||
| func afterCreateCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		scope.CallMethod("AfterCreate") | ||||
| 	} | ||||
| 	if !scope.HasError() { | ||||
| 		scope.CallMethod("AfterSave") | ||||
| 	} | ||||
| } | ||||
| @ -1,63 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| ) | ||||
| 
 | ||||
| // Define callbacks for deleting
 | ||||
| func init() { | ||||
| 	DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) | ||||
| 	DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) | ||||
| 	DefaultCallback.Delete().Register("gorm:delete", deleteCallback) | ||||
| 	DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) | ||||
| 	DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) | ||||
| } | ||||
| 
 | ||||
| // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
 | ||||
| func beforeDeleteCallback(scope *Scope) { | ||||
| 	if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { | ||||
| 		scope.Err(errors.New("missing WHERE clause while deleting")) | ||||
| 		return | ||||
| 	} | ||||
| 	if !scope.HasError() { | ||||
| 		scope.CallMethod("BeforeDelete") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
 | ||||
| func deleteCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		var extraOption string | ||||
| 		if str, ok := scope.Get("gorm:delete_option"); ok { | ||||
| 			extraOption = fmt.Sprint(str) | ||||
| 		} | ||||
| 
 | ||||
| 		deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") | ||||
| 
 | ||||
| 		if !scope.Search.Unscoped && hasDeletedAtField { | ||||
| 			scope.Raw(fmt.Sprintf( | ||||
| 				"UPDATE %v SET %v=%v%v%v", | ||||
| 				scope.QuotedTableName(), | ||||
| 				scope.Quote(deletedAtField.DBName), | ||||
| 				scope.AddToVars(scope.db.nowFunc()), | ||||
| 				addExtraSpaceIfExist(scope.CombinedConditionSql()), | ||||
| 				addExtraSpaceIfExist(extraOption), | ||||
| 			)).Exec() | ||||
| 		} else { | ||||
| 			scope.Raw(fmt.Sprintf( | ||||
| 				"DELETE FROM %v%v%v", | ||||
| 				scope.QuotedTableName(), | ||||
| 				addExtraSpaceIfExist(scope.CombinedConditionSql()), | ||||
| 				addExtraSpaceIfExist(extraOption), | ||||
| 			)).Exec() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // afterDeleteCallback will invoke `AfterDelete` method after deleting
 | ||||
| func afterDeleteCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		scope.CallMethod("AfterDelete") | ||||
| 	} | ||||
| } | ||||
| @ -1,109 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // Define callbacks for querying
 | ||||
| func init() { | ||||
| 	DefaultCallback.Query().Register("gorm:query", queryCallback) | ||||
| 	DefaultCallback.Query().Register("gorm:preload", preloadCallback) | ||||
| 	DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) | ||||
| } | ||||
| 
 | ||||
| // queryCallback used to query data from database
 | ||||
| func queryCallback(scope *Scope) { | ||||
| 	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	//we are only preloading relations, dont touch base model
 | ||||
| 	if _, skip := scope.InstanceGet("gorm:only_preload"); skip { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	defer scope.trace(NowFunc()) | ||||
| 
 | ||||
| 	var ( | ||||
| 		isSlice, isPtr bool | ||||
| 		resultType     reflect.Type | ||||
| 		results        = scope.IndirectValue() | ||||
| 	) | ||||
| 
 | ||||
| 	if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { | ||||
| 		if primaryField := scope.PrimaryField(); primaryField != nil { | ||||
| 			scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if value, ok := scope.Get("gorm:query_destination"); ok { | ||||
| 		results = indirect(reflect.ValueOf(value)) | ||||
| 	} | ||||
| 
 | ||||
| 	if kind := results.Kind(); kind == reflect.Slice { | ||||
| 		isSlice = true | ||||
| 		resultType = results.Type().Elem() | ||||
| 		results.Set(reflect.MakeSlice(results.Type(), 0, 0)) | ||||
| 
 | ||||
| 		if resultType.Kind() == reflect.Ptr { | ||||
| 			isPtr = true | ||||
| 			resultType = resultType.Elem() | ||||
| 		} | ||||
| 	} else if kind != reflect.Struct { | ||||
| 		scope.Err(errors.New("unsupported destination, should be slice or struct")) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	scope.prepareQuerySQL() | ||||
| 
 | ||||
| 	if !scope.HasError() { | ||||
| 		scope.db.RowsAffected = 0 | ||||
| 
 | ||||
| 		if str, ok := scope.Get("gorm:query_hint"); ok { | ||||
| 			scope.SQL = fmt.Sprint(str) + scope.SQL | ||||
| 		} | ||||
| 
 | ||||
| 		if str, ok := scope.Get("gorm:query_option"); ok { | ||||
| 			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) | ||||
| 		} | ||||
| 
 | ||||
| 		if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { | ||||
| 			defer rows.Close() | ||||
| 
 | ||||
| 			columns, _ := rows.Columns() | ||||
| 			for rows.Next() { | ||||
| 				scope.db.RowsAffected++ | ||||
| 
 | ||||
| 				elem := results | ||||
| 				if isSlice { | ||||
| 					elem = reflect.New(resultType).Elem() | ||||
| 				} | ||||
| 
 | ||||
| 				scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) | ||||
| 
 | ||||
| 				if isSlice { | ||||
| 					if isPtr { | ||||
| 						results.Set(reflect.Append(results, elem.Addr())) | ||||
| 					} else { | ||||
| 						results.Set(reflect.Append(results, elem)) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if err := rows.Err(); err != nil { | ||||
| 				scope.Err(err) | ||||
| 			} else if scope.db.RowsAffected == 0 && !isSlice { | ||||
| 				scope.Err(ErrRecordNotFound) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // afterQueryCallback will invoke `AfterFind` method after querying
 | ||||
| func afterQueryCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		scope.CallMethod("AfterFind") | ||||
| 	} | ||||
| } | ||||
| @ -1,410 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // preloadCallback used to preload associations
 | ||||
| func preloadCallback(scope *Scope) { | ||||
| 	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if ap, ok := scope.Get("gorm:auto_preload"); ok { | ||||
| 		// If gorm:auto_preload IS NOT a bool then auto preload.
 | ||||
| 		// Else if it IS a bool, use the value
 | ||||
| 		if apb, ok := ap.(bool); !ok { | ||||
| 			autoPreload(scope) | ||||
| 		} else if apb { | ||||
| 			autoPreload(scope) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if scope.Search.preload == nil || scope.HasError() { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	var ( | ||||
| 		preloadedMap = map[string]bool{} | ||||
| 		fields       = scope.Fields() | ||||
| 	) | ||||
| 
 | ||||
| 	for _, preload := range scope.Search.preload { | ||||
| 		var ( | ||||
| 			preloadFields = strings.Split(preload.schema, ".") | ||||
| 			currentScope  = scope | ||||
| 			currentFields = fields | ||||
| 		) | ||||
| 
 | ||||
| 		for idx, preloadField := range preloadFields { | ||||
| 			var currentPreloadConditions []interface{} | ||||
| 
 | ||||
| 			if currentScope == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			// if not preloaded
 | ||||
| 			if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { | ||||
| 
 | ||||
| 				// assign search conditions to last preload
 | ||||
| 				if idx == len(preloadFields)-1 { | ||||
| 					currentPreloadConditions = preload.conditions | ||||
| 				} | ||||
| 
 | ||||
| 				for _, field := range currentFields { | ||||
| 					if field.Name != preloadField || field.Relationship == nil { | ||||
| 						continue | ||||
| 					} | ||||
| 
 | ||||
| 					switch field.Relationship.Kind { | ||||
| 					case "has_one": | ||||
| 						currentScope.handleHasOnePreload(field, currentPreloadConditions) | ||||
| 					case "has_many": | ||||
| 						currentScope.handleHasManyPreload(field, currentPreloadConditions) | ||||
| 					case "belongs_to": | ||||
| 						currentScope.handleBelongsToPreload(field, currentPreloadConditions) | ||||
| 					case "many_to_many": | ||||
| 						currentScope.handleManyToManyPreload(field, currentPreloadConditions) | ||||
| 					default: | ||||
| 						scope.Err(errors.New("unsupported relation")) | ||||
| 					} | ||||
| 
 | ||||
| 					preloadedMap[preloadKey] = true | ||||
| 					break | ||||
| 				} | ||||
| 
 | ||||
| 				if !preloadedMap[preloadKey] { | ||||
| 					scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// preload next level
 | ||||
| 			if idx < len(preloadFields)-1 { | ||||
| 				currentScope = currentScope.getColumnAsScope(preloadField) | ||||
| 				if currentScope != nil { | ||||
| 					currentFields = currentScope.Fields() | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func autoPreload(scope *Scope) { | ||||
| 	for _, field := range scope.Fields() { | ||||
| 		if field.Relationship == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if val, ok := field.TagSettingsGet("PRELOAD"); ok { | ||||
| 			if preload, err := strconv.ParseBool(val); err != nil { | ||||
| 				scope.Err(errors.New("invalid preload option")) | ||||
| 				return | ||||
| 			} else if !preload { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		scope.Search.Preload(field.Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { | ||||
| 	var ( | ||||
| 		preloadDB         = scope.NewDB() | ||||
| 		preloadConditions []interface{} | ||||
| 	) | ||||
| 
 | ||||
| 	for _, condition := range conditions { | ||||
| 		if scopes, ok := condition.(func(*DB) *DB); ok { | ||||
| 			preloadDB = scopes(preloadDB) | ||||
| 		} else { | ||||
| 			preloadConditions = append(preloadConditions, condition) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return preloadDB, preloadConditions | ||||
| } | ||||
| 
 | ||||
| // handleHasOnePreload used to preload has one associations
 | ||||
| func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { | ||||
| 	relation := field.Relationship | ||||
| 
 | ||||
| 	// get relations's primary keys
 | ||||
| 	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) | ||||
| 	if len(primaryKeys) == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// preload conditions
 | ||||
| 	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | ||||
| 
 | ||||
| 	// find relations
 | ||||
| 	query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) | ||||
| 	values := toQueryValues(primaryKeys) | ||||
| 	if relation.PolymorphicType != "" { | ||||
| 		query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) | ||||
| 		values = append(values, relation.PolymorphicValue) | ||||
| 	} | ||||
| 
 | ||||
| 	results := makeSlice(field.Struct.Type) | ||||
| 	scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) | ||||
| 
 | ||||
| 	// assign find results
 | ||||
| 	var ( | ||||
| 		resultsValue       = indirect(reflect.ValueOf(results)) | ||||
| 		indirectScopeValue = scope.IndirectValue() | ||||
| 	) | ||||
| 
 | ||||
| 	if indirectScopeValue.Kind() == reflect.Slice { | ||||
| 		foreignValuesToResults := make(map[string]reflect.Value) | ||||
| 		for i := 0; i < resultsValue.Len(); i++ { | ||||
| 			result := resultsValue.Index(i) | ||||
| 			foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) | ||||
| 			foreignValuesToResults[foreignValues] = result | ||||
| 		} | ||||
| 		for j := 0; j < indirectScopeValue.Len(); j++ { | ||||
| 			indirectValue := indirect(indirectScopeValue.Index(j)) | ||||
| 			valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) | ||||
| 			if result, found := foreignValuesToResults[valueString]; found { | ||||
| 				indirectValue.FieldByName(field.Name).Set(result) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		for i := 0; i < resultsValue.Len(); i++ { | ||||
| 			result := resultsValue.Index(i) | ||||
| 			scope.Err(field.Set(result)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // handleHasManyPreload used to preload has many associations
 | ||||
| func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { | ||||
| 	relation := field.Relationship | ||||
| 
 | ||||
| 	// get relations's primary keys
 | ||||
| 	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) | ||||
| 	if len(primaryKeys) == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// preload conditions
 | ||||
| 	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | ||||
| 
 | ||||
| 	// find relations
 | ||||
| 	query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) | ||||
| 	values := toQueryValues(primaryKeys) | ||||
| 	if relation.PolymorphicType != "" { | ||||
| 		query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) | ||||
| 		values = append(values, relation.PolymorphicValue) | ||||
| 	} | ||||
| 
 | ||||
| 	results := makeSlice(field.Struct.Type) | ||||
| 	scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) | ||||
| 
 | ||||
| 	// assign find results
 | ||||
| 	var ( | ||||
| 		resultsValue       = indirect(reflect.ValueOf(results)) | ||||
| 		indirectScopeValue = scope.IndirectValue() | ||||
| 	) | ||||
| 
 | ||||
| 	if indirectScopeValue.Kind() == reflect.Slice { | ||||
| 		preloadMap := make(map[string][]reflect.Value) | ||||
| 		for i := 0; i < resultsValue.Len(); i++ { | ||||
| 			result := resultsValue.Index(i) | ||||
| 			foreignValues := getValueFromFields(result, relation.ForeignFieldNames) | ||||
| 			preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) | ||||
| 		} | ||||
| 
 | ||||
| 		for j := 0; j < indirectScopeValue.Len(); j++ { | ||||
| 			object := indirect(indirectScopeValue.Index(j)) | ||||
| 			objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) | ||||
| 			f := object.FieldByName(field.Name) | ||||
| 			if results, ok := preloadMap[toString(objectRealValue)]; ok { | ||||
| 				f.Set(reflect.Append(f, results...)) | ||||
| 			} else { | ||||
| 				f.Set(reflect.MakeSlice(f.Type(), 0, 0)) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		scope.Err(field.Set(resultsValue)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // handleBelongsToPreload used to preload belongs to associations
 | ||||
| func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { | ||||
| 	relation := field.Relationship | ||||
| 
 | ||||
| 	// preload conditions
 | ||||
| 	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | ||||
| 
 | ||||
| 	// get relations's primary keys
 | ||||
| 	primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) | ||||
| 	if len(primaryKeys) == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// find relations
 | ||||
| 	results := makeSlice(field.Struct.Type) | ||||
| 	scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) | ||||
| 
 | ||||
| 	// assign find results
 | ||||
| 	var ( | ||||
| 		resultsValue       = indirect(reflect.ValueOf(results)) | ||||
| 		indirectScopeValue = scope.IndirectValue() | ||||
| 	) | ||||
| 
 | ||||
| 	foreignFieldToObjects := make(map[string][]*reflect.Value) | ||||
| 	if indirectScopeValue.Kind() == reflect.Slice { | ||||
| 		for j := 0; j < indirectScopeValue.Len(); j++ { | ||||
| 			object := indirect(indirectScopeValue.Index(j)) | ||||
| 			valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) | ||||
| 			foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for i := 0; i < resultsValue.Len(); i++ { | ||||
| 		result := resultsValue.Index(i) | ||||
| 		if indirectScopeValue.Kind() == reflect.Slice { | ||||
| 			valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) | ||||
| 			if objects, found := foreignFieldToObjects[valueString]; found { | ||||
| 				for _, object := range objects { | ||||
| 					object.FieldByName(field.Name).Set(result) | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			scope.Err(field.Set(result)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // handleManyToManyPreload used to preload many to many associations
 | ||||
| func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { | ||||
| 	var ( | ||||
| 		relation         = field.Relationship | ||||
| 		joinTableHandler = relation.JoinTableHandler | ||||
| 		fieldType        = field.Struct.Type.Elem() | ||||
| 		foreignKeyValue  interface{} | ||||
| 		foreignKeyType   = reflect.ValueOf(&foreignKeyValue).Type() | ||||
| 		linkHash         = map[string][]reflect.Value{} | ||||
| 		isPtr            bool | ||||
| 	) | ||||
| 
 | ||||
| 	if fieldType.Kind() == reflect.Ptr { | ||||
| 		isPtr = true | ||||
| 		fieldType = fieldType.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	var sourceKeys = []string{} | ||||
| 	for _, key := range joinTableHandler.SourceForeignKeys() { | ||||
| 		sourceKeys = append(sourceKeys, key.DBName) | ||||
| 	} | ||||
| 
 | ||||
| 	// preload conditions
 | ||||
| 	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) | ||||
| 
 | ||||
| 	// generate query with join table
 | ||||
| 	newScope := scope.New(reflect.New(fieldType).Interface()) | ||||
| 	preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) | ||||
| 
 | ||||
| 	if len(preloadDB.search.selects) == 0 { | ||||
| 		preloadDB = preloadDB.Select("*") | ||||
| 	} | ||||
| 
 | ||||
| 	preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) | ||||
| 
 | ||||
| 	// preload inline conditions
 | ||||
| 	if len(preloadConditions) > 0 { | ||||
| 		preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err := preloadDB.Rows() | ||||
| 
 | ||||
| 	if scope.Err(err) != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
| 
 | ||||
| 	columns, _ := rows.Columns() | ||||
| 	for rows.Next() { | ||||
| 		var ( | ||||
| 			elem   = reflect.New(fieldType).Elem() | ||||
| 			fields = scope.New(elem.Addr().Interface()).Fields() | ||||
| 		) | ||||
| 
 | ||||
| 		// register foreign keys in join tables
 | ||||
| 		var joinTableFields []*Field | ||||
| 		for _, sourceKey := range sourceKeys { | ||||
| 			joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) | ||||
| 		} | ||||
| 
 | ||||
| 		scope.scan(rows, columns, append(fields, joinTableFields...)) | ||||
| 
 | ||||
| 		scope.New(elem.Addr().Interface()). | ||||
| 			InstanceSet("gorm:skip_query_callback", true). | ||||
| 			callCallbacks(scope.db.parent.callbacks.queries) | ||||
| 
 | ||||
| 		var foreignKeys = make([]interface{}, len(sourceKeys)) | ||||
| 		// generate hashed forkey keys in join table
 | ||||
| 		for idx, joinTableField := range joinTableFields { | ||||
| 			if !joinTableField.Field.IsNil() { | ||||
| 				foreignKeys[idx] = joinTableField.Field.Elem().Interface() | ||||
| 			} | ||||
| 		} | ||||
| 		hashedSourceKeys := toString(foreignKeys) | ||||
| 
 | ||||
| 		if isPtr { | ||||
| 			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) | ||||
| 		} else { | ||||
| 			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if err := rows.Err(); err != nil { | ||||
| 		scope.Err(err) | ||||
| 	} | ||||
| 
 | ||||
| 	// assign find results
 | ||||
| 	var ( | ||||
| 		indirectScopeValue = scope.IndirectValue() | ||||
| 		fieldsSourceMap    = map[string][]reflect.Value{} | ||||
| 		foreignFieldNames  = []string{} | ||||
| 	) | ||||
| 
 | ||||
| 	for _, dbName := range relation.ForeignFieldNames { | ||||
| 		if field, ok := scope.FieldByName(dbName); ok { | ||||
| 			foreignFieldNames = append(foreignFieldNames, field.Name) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if indirectScopeValue.Kind() == reflect.Slice { | ||||
| 		for j := 0; j < indirectScopeValue.Len(); j++ { | ||||
| 			object := indirect(indirectScopeValue.Index(j)) | ||||
| 			key := toString(getValueFromFields(object, foreignFieldNames)) | ||||
| 			fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) | ||||
| 		} | ||||
| 	} else if indirectScopeValue.IsValid() { | ||||
| 		key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) | ||||
| 		fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) | ||||
| 	} | ||||
| 
 | ||||
| 	for source, fields := range fieldsSourceMap { | ||||
| 		for _, f := range fields { | ||||
| 			//If not 0 this means Value is a pointer and we already added preloaded models to it
 | ||||
| 			if f.Len() != 0 { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			v := reflect.MakeSlice(f.Type(), 0, 0) | ||||
| 			if len(linkHash[source]) > 0 { | ||||
| 				v = reflect.Append(f, linkHash[source]...) | ||||
| 			} | ||||
| 
 | ||||
| 			f.Set(v) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -1,41 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| ) | ||||
| 
 | ||||
| // Define callbacks for row query
 | ||||
| func init() { | ||||
| 	DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) | ||||
| } | ||||
| 
 | ||||
| type RowQueryResult struct { | ||||
| 	Row *sql.Row | ||||
| } | ||||
| 
 | ||||
| type RowsQueryResult struct { | ||||
| 	Rows  *sql.Rows | ||||
| 	Error error | ||||
| } | ||||
| 
 | ||||
| // queryCallback used to query data from database
 | ||||
| func rowQueryCallback(scope *Scope) { | ||||
| 	if result, ok := scope.InstanceGet("row_query_result"); ok { | ||||
| 		scope.prepareQuerySQL() | ||||
| 
 | ||||
| 		if str, ok := scope.Get("gorm:query_hint"); ok { | ||||
| 			scope.SQL = fmt.Sprint(str) + scope.SQL | ||||
| 		} | ||||
| 
 | ||||
| 		if str, ok := scope.Get("gorm:query_option"); ok { | ||||
| 			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) | ||||
| 		} | ||||
| 
 | ||||
| 		if rowResult, ok := result.(*RowQueryResult); ok { | ||||
| 			rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) | ||||
| 		} else if rowsResult, ok := result.(*RowsQueryResult); ok { | ||||
| 			rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										170
									
								
								callback_save.go
									
									
									
									
									
								
							
							
						
						
									
										170
									
								
								callback_save.go
									
									
									
									
									
								
							| @ -1,170 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| func beginTransactionCallback(scope *Scope) { | ||||
| 	scope.Begin() | ||||
| } | ||||
| 
 | ||||
| func commitOrRollbackTransactionCallback(scope *Scope) { | ||||
| 	scope.CommitOrRollback() | ||||
| } | ||||
| 
 | ||||
| func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { | ||||
| 	checkTruth := func(value interface{}) bool { | ||||
| 		if v, ok := value.(bool); ok && !v { | ||||
| 			return false | ||||
| 		} | ||||
| 
 | ||||
| 		if v, ok := value.(string); ok { | ||||
| 			v = strings.ToLower(v) | ||||
| 			return v == "true" | ||||
| 		} | ||||
| 
 | ||||
| 		return true | ||||
| 	} | ||||
| 
 | ||||
| 	if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { | ||||
| 		if r = field.Relationship; r != nil { | ||||
| 			autoUpdate, autoCreate, saveReference = true, true, true | ||||
| 
 | ||||
| 			if value, ok := scope.Get("gorm:save_associations"); ok { | ||||
| 				autoUpdate = checkTruth(value) | ||||
| 				autoCreate = autoUpdate | ||||
| 				saveReference = autoUpdate | ||||
| 			} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { | ||||
| 				autoUpdate = checkTruth(value) | ||||
| 				autoCreate = autoUpdate | ||||
| 				saveReference = autoUpdate | ||||
| 			} | ||||
| 
 | ||||
| 			if value, ok := scope.Get("gorm:association_autoupdate"); ok { | ||||
| 				autoUpdate = checkTruth(value) | ||||
| 			} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { | ||||
| 				autoUpdate = checkTruth(value) | ||||
| 			} | ||||
| 
 | ||||
| 			if value, ok := scope.Get("gorm:association_autocreate"); ok { | ||||
| 				autoCreate = checkTruth(value) | ||||
| 			} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { | ||||
| 				autoCreate = checkTruth(value) | ||||
| 			} | ||||
| 
 | ||||
| 			if value, ok := scope.Get("gorm:association_save_reference"); ok { | ||||
| 				saveReference = checkTruth(value) | ||||
| 			} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { | ||||
| 				saveReference = checkTruth(value) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func saveBeforeAssociationsCallback(scope *Scope) { | ||||
| 	for _, field := range scope.Fields() { | ||||
| 		autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) | ||||
| 
 | ||||
| 		if relationship != nil && relationship.Kind == "belongs_to" { | ||||
| 			fieldValue := field.Field.Addr().Interface() | ||||
| 			newScope := scope.New(fieldValue) | ||||
| 
 | ||||
| 			if newScope.PrimaryKeyZero() { | ||||
| 				if autoCreate { | ||||
| 					scope.Err(scope.NewDB().Save(fieldValue).Error) | ||||
| 				} | ||||
| 			} else if autoUpdate { | ||||
| 				scope.Err(scope.NewDB().Save(fieldValue).Error) | ||||
| 			} | ||||
| 
 | ||||
| 			if saveReference { | ||||
| 				if len(relationship.ForeignFieldNames) != 0 { | ||||
| 					// set value's foreign key
 | ||||
| 					for idx, fieldName := range relationship.ForeignFieldNames { | ||||
| 						associationForeignName := relationship.AssociationForeignDBNames[idx] | ||||
| 						if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { | ||||
| 							scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func saveAfterAssociationsCallback(scope *Scope) { | ||||
| 	for _, field := range scope.Fields() { | ||||
| 		autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) | ||||
| 
 | ||||
| 		if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { | ||||
| 			value := field.Field | ||||
| 
 | ||||
| 			switch value.Kind() { | ||||
| 			case reflect.Slice: | ||||
| 				for i := 0; i < value.Len(); i++ { | ||||
| 					newDB := scope.NewDB() | ||||
| 					elem := value.Index(i).Addr().Interface() | ||||
| 					newScope := newDB.NewScope(elem) | ||||
| 
 | ||||
| 					if saveReference { | ||||
| 						if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { | ||||
| 							for idx, fieldName := range relationship.ForeignFieldNames { | ||||
| 								associationForeignName := relationship.AssociationForeignDBNames[idx] | ||||
| 								if f, ok := scope.FieldByName(associationForeignName); ok { | ||||
| 									scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						if relationship.PolymorphicType != "" { | ||||
| 							scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					if newScope.PrimaryKeyZero() { | ||||
| 						if autoCreate { | ||||
| 							scope.Err(newDB.Save(elem).Error) | ||||
| 						} | ||||
| 					} else if autoUpdate { | ||||
| 						scope.Err(newDB.Save(elem).Error) | ||||
| 					} | ||||
| 
 | ||||
| 					if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { | ||||
| 						if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { | ||||
| 							scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			default: | ||||
| 				elem := value.Addr().Interface() | ||||
| 				newScope := scope.New(elem) | ||||
| 
 | ||||
| 				if saveReference { | ||||
| 					if len(relationship.ForeignFieldNames) != 0 { | ||||
| 						for idx, fieldName := range relationship.ForeignFieldNames { | ||||
| 							associationForeignName := relationship.AssociationForeignDBNames[idx] | ||||
| 							if f, ok := scope.FieldByName(associationForeignName); ok { | ||||
| 								scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					if relationship.PolymorphicType != "" { | ||||
| 						scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if newScope.PrimaryKeyZero() { | ||||
| 					if autoCreate { | ||||
| 						scope.Err(scope.NewDB().Save(elem).Error) | ||||
| 					} | ||||
| 				} else if autoUpdate { | ||||
| 					scope.Err(scope.NewDB().Save(elem).Error) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -1,112 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { | ||||
| 	var names []string | ||||
| 	for _, f := range funcs { | ||||
| 		fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") | ||||
| 		names = append(names, fnames[len(fnames)-1]) | ||||
| 	} | ||||
| 	return reflect.DeepEqual(names, fnames) | ||||
| } | ||||
| 
 | ||||
| func create(s *Scope)        {} | ||||
| func beforeCreate1(s *Scope) {} | ||||
| func beforeCreate2(s *Scope) {} | ||||
| func afterCreate1(s *Scope)  {} | ||||
| func afterCreate2(s *Scope)  {} | ||||
| 
 | ||||
| func TestRegisterCallback(t *testing.T) { | ||||
| 	var callback = &Callback{logger: defaultLogger} | ||||
| 
 | ||||
| 	callback.Create().Register("before_create1", beforeCreate1) | ||||
| 	callback.Create().Register("before_create2", beforeCreate2) | ||||
| 	callback.Create().Register("create", create) | ||||
| 	callback.Create().Register("after_create1", afterCreate1) | ||||
| 	callback.Create().Register("after_create2", afterCreate2) | ||||
| 
 | ||||
| 	if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | ||||
| 		t.Errorf("register callback") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRegisterCallbackWithOrder(t *testing.T) { | ||||
| 	var callback1 = &Callback{logger: defaultLogger} | ||||
| 	callback1.Create().Register("before_create1", beforeCreate1) | ||||
| 	callback1.Create().Register("create", create) | ||||
| 	callback1.Create().Register("after_create1", afterCreate1) | ||||
| 	callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) | ||||
| 	if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | ||||
| 		t.Errorf("register callback with order") | ||||
| 	} | ||||
| 
 | ||||
| 	var callback2 = &Callback{logger: defaultLogger} | ||||
| 
 | ||||
| 	callback2.Update().Register("create", create) | ||||
| 	callback2.Update().Before("create").Register("before_create1", beforeCreate1) | ||||
| 	callback2.Update().After("after_create2").Register("after_create1", afterCreate1) | ||||
| 	callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) | ||||
| 	callback2.Update().Register("after_create2", afterCreate2) | ||||
| 
 | ||||
| 	if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { | ||||
| 		t.Errorf("register callback with order") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRegisterCallbackWithComplexOrder(t *testing.T) { | ||||
| 	var callback1 = &Callback{logger: defaultLogger} | ||||
| 
 | ||||
| 	callback1.Query().Before("after_create1").After("before_create1").Register("create", create) | ||||
| 	callback1.Query().Register("before_create1", beforeCreate1) | ||||
| 	callback1.Query().Register("after_create1", afterCreate1) | ||||
| 
 | ||||
| 	if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { | ||||
| 		t.Errorf("register callback with order") | ||||
| 	} | ||||
| 
 | ||||
| 	var callback2 = &Callback{logger: defaultLogger} | ||||
| 
 | ||||
| 	callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) | ||||
| 	callback2.Delete().Before("create").Register("before_create1", beforeCreate1) | ||||
| 	callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) | ||||
| 	callback2.Delete().Register("after_create1", afterCreate1) | ||||
| 	callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) | ||||
| 
 | ||||
| 	if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { | ||||
| 		t.Errorf("register callback with order") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func replaceCreate(s *Scope) {} | ||||
| 
 | ||||
| func TestReplaceCallback(t *testing.T) { | ||||
| 	var callback = &Callback{logger: defaultLogger} | ||||
| 
 | ||||
| 	callback.Create().Before("after_create1").After("before_create1").Register("create", create) | ||||
| 	callback.Create().Register("before_create1", beforeCreate1) | ||||
| 	callback.Create().Register("after_create1", afterCreate1) | ||||
| 	callback.Create().Replace("create", replaceCreate) | ||||
| 
 | ||||
| 	if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { | ||||
| 		t.Errorf("replace callback") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRemoveCallback(t *testing.T) { | ||||
| 	var callback = &Callback{logger: defaultLogger} | ||||
| 
 | ||||
| 	callback.Create().Before("after_create1").After("before_create1").Register("create", create) | ||||
| 	callback.Create().Register("before_create1", beforeCreate1) | ||||
| 	callback.Create().Register("after_create1", afterCreate1) | ||||
| 	callback.Create().Remove("create") | ||||
| 
 | ||||
| 	if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { | ||||
| 		t.Errorf("remove callback") | ||||
| 	} | ||||
| } | ||||
| @ -1,121 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // Define callbacks for updating
 | ||||
| func init() { | ||||
| 	DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:update", updateCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) | ||||
| 	DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) | ||||
| } | ||||
| 
 | ||||
| // assignUpdatingAttributesCallback assign updating attributes to model
 | ||||
| func assignUpdatingAttributesCallback(scope *Scope) { | ||||
| 	if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { | ||||
| 		if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { | ||||
| 			scope.InstanceSet("gorm:update_attrs", updateMaps) | ||||
| 		} else { | ||||
| 			scope.SkipLeft() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
 | ||||
| func beforeUpdateCallback(scope *Scope) { | ||||
| 	if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { | ||||
| 		scope.Err(errors.New("missing WHERE clause while updating")) | ||||
| 		return | ||||
| 	} | ||||
| 	if _, ok := scope.Get("gorm:update_column"); !ok { | ||||
| 		if !scope.HasError() { | ||||
| 			scope.CallMethod("BeforeSave") | ||||
| 		} | ||||
| 		if !scope.HasError() { | ||||
| 			scope.CallMethod("BeforeUpdate") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
 | ||||
| func updateTimeStampForUpdateCallback(scope *Scope) { | ||||
| 	if _, ok := scope.Get("gorm:update_column"); !ok { | ||||
| 		scope.SetColumn("UpdatedAt", scope.db.nowFunc()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // updateCallback the callback used to update data to database
 | ||||
| func updateCallback(scope *Scope) { | ||||
| 	if !scope.HasError() { | ||||
| 		var sqls []string | ||||
| 
 | ||||
| 		if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { | ||||
| 			// Sort the column names so that the generated SQL is the same every time.
 | ||||
| 			updateMap := updateAttrs.(map[string]interface{}) | ||||
| 			var columns []string | ||||
| 			for c := range updateMap { | ||||
| 				columns = append(columns, c) | ||||
| 			} | ||||
| 			sort.Strings(columns) | ||||
| 
 | ||||
| 			for _, column := range columns { | ||||
| 				value := updateMap[column] | ||||
| 				sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) | ||||
| 			} | ||||
| 		} else { | ||||
| 			for _, field := range scope.Fields() { | ||||
| 				if scope.changeableField(field) { | ||||
| 					if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { | ||||
| 						if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { | ||||
| 							sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) | ||||
| 						} | ||||
| 					} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { | ||||
| 						for _, foreignKey := range relationship.ForeignDBNames { | ||||
| 							if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { | ||||
| 								sqls = append(sqls, | ||||
| 									fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		var extraOption string | ||||
| 		if str, ok := scope.Get("gorm:update_option"); ok { | ||||
| 			extraOption = fmt.Sprint(str) | ||||
| 		} | ||||
| 
 | ||||
| 		if len(sqls) > 0 { | ||||
| 			scope.Raw(fmt.Sprintf( | ||||
| 				"UPDATE %v SET %v%v%v", | ||||
| 				scope.QuotedTableName(), | ||||
| 				strings.Join(sqls, ", "), | ||||
| 				addExtraSpaceIfExist(scope.CombinedConditionSql()), | ||||
| 				addExtraSpaceIfExist(extraOption), | ||||
| 			)).Exec() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
 | ||||
| func afterUpdateCallback(scope *Scope) { | ||||
| 	if _, ok := scope.Get("gorm:update_column"); !ok { | ||||
| 		if !scope.HasError() { | ||||
| 			scope.CallMethod("AfterUpdate") | ||||
| 		} | ||||
| 		if !scope.HasError() { | ||||
| 			scope.CallMethod("AfterSave") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -1,249 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func (s *Product) BeforeCreate() (err error) { | ||||
| 	if s.Code == "Invalid" { | ||||
| 		err = errors.New("invalid product") | ||||
| 	} | ||||
| 	s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeUpdate() (err error) { | ||||
| 	if s.Code == "dont_update" { | ||||
| 		err = errors.New("can't update") | ||||
| 	} | ||||
| 	s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeSave() (err error) { | ||||
| 	if s.Code == "dont_save" { | ||||
| 		err = errors.New("can't save") | ||||
| 	} | ||||
| 	s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterFind() { | ||||
| 	s.AfterFindCallTimes = s.AfterFindCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterCreate(tx *gorm.DB) { | ||||
| 	tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterUpdate() { | ||||
| 	s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterSave() (err error) { | ||||
| 	if s.Code == "after_save_error" { | ||||
| 		err = errors.New("can't save") | ||||
| 	} | ||||
| 	s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeDelete() (err error) { | ||||
| 	if s.Code == "dont_delete" { | ||||
| 		err = errors.New("can't delete") | ||||
| 	} | ||||
| 	s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterDelete() (err error) { | ||||
| 	if s.Code == "after_delete_error" { | ||||
| 		err = errors.New("can't delete") | ||||
| 	} | ||||
| 	s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) GetCallTimes() []int64 { | ||||
| 	return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} | ||||
| } | ||||
| 
 | ||||
| func TestRunCallbacks(t *testing.T) { | ||||
| 	p := Product{Code: "unique_code", Price: 100} | ||||
| 	DB.Save(&p) | ||||
| 
 | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { | ||||
| 		t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where("Code = ?", "unique_code").First(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { | ||||
| 		t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	p.Price = 200 | ||||
| 	DB.Save(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { | ||||
| 		t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	var products []Product | ||||
| 	DB.Find(&products, "code = ?", "unique_code") | ||||
| 	if products[0].AfterFindCallTimes != 2 { | ||||
| 		t.Errorf("AfterFind callbacks should work with slice") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where("Code = ?", "unique_code").First(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { | ||||
| 		t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Delete(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { | ||||
| 		t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { | ||||
| 		t.Errorf("Can't find a deleted record") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCallbacksWithErrors(t *testing.T) { | ||||
| 	p := Product{Code: "Invalid", Price: 100} | ||||
| 	if DB.Save(&p).Error == nil { | ||||
| 		t.Errorf("An error from before create callbacks happened when create with invalid value") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { | ||||
| 		t.Errorf("Should not save record that have errors") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { | ||||
| 		t.Errorf("An error from after create callbacks happened when create with invalid value") | ||||
| 	} | ||||
| 
 | ||||
| 	p2 := Product{Code: "update_callback", Price: 100} | ||||
| 	DB.Save(&p2) | ||||
| 
 | ||||
| 	p2.Code = "dont_update" | ||||
| 	if DB.Save(&p2).Error == nil { | ||||
| 		t.Errorf("An error from before update callbacks happened when update with invalid value") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { | ||||
| 		t.Errorf("Record Should not be updated due to errors happened in before update callback") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { | ||||
| 		t.Errorf("Record Should not be updated due to errors happened in before update callback") | ||||
| 	} | ||||
| 
 | ||||
| 	p2.Code = "dont_save" | ||||
| 	if DB.Save(&p2).Error == nil { | ||||
| 		t.Errorf("An error from before save callbacks happened when update with invalid value") | ||||
| 	} | ||||
| 
 | ||||
| 	p3 := Product{Code: "dont_delete", Price: 100} | ||||
| 	DB.Save(&p3) | ||||
| 	if DB.Delete(&p3).Error == nil { | ||||
| 		t.Errorf("An error from before delete callbacks happened when delete") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { | ||||
| 		t.Errorf("An error from before delete callbacks happened") | ||||
| 	} | ||||
| 
 | ||||
| 	p4 := Product{Code: "after_save_error", Price: 100} | ||||
| 	DB.Save(&p4) | ||||
| 	if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { | ||||
| 		t.Errorf("Record should be reverted if get an error in after save callback") | ||||
| 	} | ||||
| 
 | ||||
| 	p5 := Product{Code: "after_delete_error", Price: 100} | ||||
| 	DB.Save(&p5) | ||||
| 	if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | ||||
| 		t.Errorf("Record should be found") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Delete(&p5) | ||||
| 	if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { | ||||
| 		t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGetCallback(t *testing.T) { | ||||
| 	scope := DB.NewScope(nil) | ||||
| 
 | ||||
| 	if DB.Callback().Create().Get("gorm:test_callback") != nil { | ||||
| 		t.Errorf("`gorm:test_callback` should be nil") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) | ||||
| 	callback := DB.Callback().Create().Get("gorm:test_callback") | ||||
| 	if callback == nil { | ||||
| 		t.Errorf("`gorm:test_callback` should be non-nil") | ||||
| 	} | ||||
| 	callback(scope) | ||||
| 	if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { | ||||
| 		t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) | ||||
| 	callback = DB.Callback().Create().Get("gorm:test_callback") | ||||
| 	if callback == nil { | ||||
| 		t.Errorf("`gorm:test_callback` should be non-nil") | ||||
| 	} | ||||
| 	callback(scope) | ||||
| 	if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { | ||||
| 		t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Callback().Create().Remove("gorm:test_callback") | ||||
| 	if DB.Callback().Create().Get("gorm:test_callback") != nil { | ||||
| 		t.Errorf("`gorm:test_callback` should be nil") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) | ||||
| 	callback = DB.Callback().Create().Get("gorm:test_callback") | ||||
| 	if callback == nil { | ||||
| 		t.Errorf("`gorm:test_callback` should be non-nil") | ||||
| 	} | ||||
| 	callback(scope) | ||||
| 	if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { | ||||
| 		t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUseDefaultCallback(t *testing.T) { | ||||
| 	createCallbackName := "gorm:test_use_default_callback_for_create" | ||||
| 	gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { | ||||
| 		// nop
 | ||||
| 	}) | ||||
| 	if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { | ||||
| 		t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) | ||||
| 	} | ||||
| 	gorm.DefaultCallback.Create().Remove(createCallbackName) | ||||
| 	if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { | ||||
| 		t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) | ||||
| 	} | ||||
| 
 | ||||
| 	updateCallbackName := "gorm:test_use_default_callback_for_update" | ||||
| 	scopeValueName := "gorm:test_use_default_callback_for_update_value" | ||||
| 	gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { | ||||
| 		scope.Set(scopeValueName, 1) | ||||
| 	}) | ||||
| 	gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { | ||||
| 		scope.Set(scopeValueName, 2) | ||||
| 	}) | ||||
| 
 | ||||
| 	scope := DB.NewScope(nil) | ||||
| 	callback := gorm.DefaultCallback.Update().Get(updateCallbackName) | ||||
| 	callback(scope) | ||||
| 	if v, ok := scope.Get(scopeValueName); !ok || v != 2 { | ||||
| 		t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										288
									
								
								create_test.go
									
									
									
									
									
								
							
							
						
						
									
										288
									
								
								create_test.go
									
									
									
									
									
								
							| @ -1,288 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/now" | ||||
| ) | ||||
| 
 | ||||
| func TestCreate(t *testing.T) { | ||||
| 	float := 35.03554004971999 | ||||
| 	now := time.Now() | ||||
| 	user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} | ||||
| 
 | ||||
| 	if !DB.NewRecord(user) || !DB.NewRecord(&user) { | ||||
| 		t.Error("User should be new record before create") | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Save(&user).RowsAffected; count != 1 { | ||||
| 		t.Error("There should be one record be affected when create record") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.NewRecord(user) || DB.NewRecord(&user) { | ||||
| 		t.Error("User should not new record after save") | ||||
| 	} | ||||
| 
 | ||||
| 	var newUser User | ||||
| 	if err := DB.First(&newUser, user.Id).Error; err != nil { | ||||
| 		t.Errorf("No error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { | ||||
| 		t.Errorf("User's PasswordHash should be saved ([]byte)") | ||||
| 	} | ||||
| 
 | ||||
| 	if newUser.Age != 18 { | ||||
| 		t.Errorf("User's Age should be saved (int)") | ||||
| 	} | ||||
| 
 | ||||
| 	if newUser.UserNum != Num(111) { | ||||
| 		t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) | ||||
| 	} | ||||
| 
 | ||||
| 	if newUser.Latitude != float { | ||||
| 		t.Errorf("Float64 should not be changed after save") | ||||
| 	} | ||||
| 
 | ||||
| 	if user.CreatedAt.IsZero() { | ||||
| 		t.Errorf("Should have created_at after create") | ||||
| 	} | ||||
| 
 | ||||
| 	if newUser.CreatedAt.IsZero() { | ||||
| 		t.Errorf("Should have created_at after create") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(user).Update("name", "create_user_new_name") | ||||
| 	DB.First(&user, user.Id) | ||||
| 	if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("CreatedAt should not be changed after update") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateEmptyStrut(t *testing.T) { | ||||
| 	type EmptyStruct struct { | ||||
| 		ID uint | ||||
| 	} | ||||
| 	DB.AutoMigrate(&EmptyStruct{}) | ||||
| 
 | ||||
| 	if err := DB.Create(&EmptyStruct{}).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when creating user, but got %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateWithExistingTimestamp(t *testing.T) { | ||||
| 	user := User{Name: "CreateUserExistingTimestamp"} | ||||
| 
 | ||||
| 	timeA := now.MustParse("2016-01-01") | ||||
| 	user.CreatedAt = timeA | ||||
| 	user.UpdatedAt = timeA | ||||
| 	DB.Save(&user) | ||||
| 
 | ||||
| 	if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("CreatedAt should not be changed") | ||||
| 	} | ||||
| 
 | ||||
| 	if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("UpdatedAt should not be changed") | ||||
| 	} | ||||
| 
 | ||||
| 	var newUser User | ||||
| 	DB.First(&newUser, user.Id) | ||||
| 
 | ||||
| 	if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("CreatedAt should not be changed") | ||||
| 	} | ||||
| 
 | ||||
| 	if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("UpdatedAt should not be changed") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateWithNowFuncOverride(t *testing.T) { | ||||
| 	user1 := User{Name: "CreateUserTimestampOverride"} | ||||
| 
 | ||||
| 	timeA := now.MustParse("2016-01-01") | ||||
| 
 | ||||
| 	// do DB.New() because we don't want this test to affect other tests
 | ||||
| 	db1 := DB.New() | ||||
| 	// set the override to use static timeA
 | ||||
| 	db1.SetNowFuncOverride(func() time.Time { | ||||
| 		return timeA | ||||
| 	}) | ||||
| 	// call .New again to check the override is carried over as well during clone
 | ||||
| 	db1 = db1.New() | ||||
| 
 | ||||
| 	db1.Save(&user1) | ||||
| 
 | ||||
| 	if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("CreatedAt be using the nowFuncOverride") | ||||
| 	} | ||||
| 	if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("UpdatedAt be using the nowFuncOverride") | ||||
| 	} | ||||
| 
 | ||||
| 	// now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set
 | ||||
| 	// to make sure that setting it only affected the above instance
 | ||||
| 
 | ||||
| 	user2 := User{Name: "CreateUserTimestampOverrideNoMore"} | ||||
| 
 | ||||
| 	db2 := DB.New() | ||||
| 
 | ||||
| 	db2.Save(&user2) | ||||
| 
 | ||||
| 	if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("CreatedAt no longer be using the nowFuncOverride") | ||||
| 	} | ||||
| 	if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { | ||||
| 		t.Errorf("UpdatedAt no longer be using the nowFuncOverride") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type AutoIncrementUser struct { | ||||
| 	User | ||||
| 	Sequence uint `gorm:"AUTO_INCREMENT"` | ||||
| } | ||||
| 
 | ||||
| func TestCreateWithAutoIncrement(t *testing.T) { | ||||
| 	if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { | ||||
| 		t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.AutoMigrate(&AutoIncrementUser{}) | ||||
| 
 | ||||
| 	user1 := AutoIncrementUser{} | ||||
| 	user2 := AutoIncrementUser{} | ||||
| 
 | ||||
| 	DB.Create(&user1) | ||||
| 	DB.Create(&user2) | ||||
| 
 | ||||
| 	if user2.Sequence-user1.Sequence != 1 { | ||||
| 		t.Errorf("Auto increment should apply on Sequence") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateWithNoGORMPrimayKey(t *testing.T) { | ||||
| 	if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { | ||||
| 		t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") | ||||
| 	} | ||||
| 
 | ||||
| 	jt := JoinTable{From: 1, To: 2} | ||||
| 	err := DB.Create(&jt).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { | ||||
| 	animal := Animal{Name: "Ferdinand"} | ||||
| 	if DB.Save(&animal).Error != nil { | ||||
| 		t.Errorf("No error should happen when create a record without std primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	if animal.Counter == 0 { | ||||
| 		t.Errorf("No std primary key should be filled value after create") | ||||
| 	} | ||||
| 
 | ||||
| 	if animal.Name != "Ferdinand" { | ||||
| 		t.Errorf("Default value should be overrided") | ||||
| 	} | ||||
| 
 | ||||
| 	// Test create with default value not overrided
 | ||||
| 	an := Animal{From: "nerdz"} | ||||
| 
 | ||||
| 	if DB.Save(&an).Error != nil { | ||||
| 		t.Errorf("No error should happen when create an record without std primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	// We must fetch the value again, to have the default fields updated
 | ||||
| 	// (We can't do this in the update statements, since sql default can be expressions
 | ||||
| 	// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
 | ||||
| 	DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) | ||||
| 
 | ||||
| 	if an.Name != "galeone" { | ||||
| 		t.Errorf("Default value should fill the field. But got %v", an.Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestAnonymousScanner(t *testing.T) { | ||||
| 	user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} | ||||
| 	DB.Save(&user) | ||||
| 
 | ||||
| 	var user2 User | ||||
| 	DB.First(&user2, "name = ?", "anonymous_scanner") | ||||
| 	if user2.Role.Name != "admin" { | ||||
| 		t.Errorf("Should be able to get anonymous scanner") | ||||
| 	} | ||||
| 
 | ||||
| 	if !user2.Role.IsAdmin() { | ||||
| 		t.Errorf("Should be able to get anonymous scanner") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestAnonymousField(t *testing.T) { | ||||
| 	user := User{Name: "anonymous_field", Company: Company{Name: "company"}} | ||||
| 	DB.Save(&user) | ||||
| 
 | ||||
| 	var user2 User | ||||
| 	DB.First(&user2, "name = ?", "anonymous_field") | ||||
| 	DB.Model(&user2).Related(&user2.Company) | ||||
| 	if user2.Company.Name != "company" { | ||||
| 		t.Errorf("Should be able to get anonymous field") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithCreate(t *testing.T) { | ||||
| 	user := getPreparedUser("select_user", "select_with_create") | ||||
| 	DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) | ||||
| 
 | ||||
| 	var queryuser User | ||||
| 	DB.Preload("BillingAddress").Preload("ShippingAddress"). | ||||
| 		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) | ||||
| 
 | ||||
| 	if queryuser.Name != user.Name || queryuser.Age == user.Age { | ||||
| 		t.Errorf("Should only create users with name column") | ||||
| 	} | ||||
| 
 | ||||
| 	if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || | ||||
| 		queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { | ||||
| 		t.Errorf("Should only create selected relationships") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOmitWithCreate(t *testing.T) { | ||||
| 	user := getPreparedUser("omit_user", "omit_with_create") | ||||
| 	DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) | ||||
| 
 | ||||
| 	var queryuser User | ||||
| 	DB.Preload("BillingAddress").Preload("ShippingAddress"). | ||||
| 		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) | ||||
| 
 | ||||
| 	if queryuser.Name == user.Name || queryuser.Age != user.Age { | ||||
| 		t.Errorf("Should only create users with age column") | ||||
| 	} | ||||
| 
 | ||||
| 	if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || | ||||
| 		queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { | ||||
| 		t.Errorf("Should not create omitted relationships") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateIgnore(t *testing.T) { | ||||
| 	float := 35.03554004971999 | ||||
| 	now := time.Now() | ||||
| 	user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} | ||||
| 
 | ||||
| 	if !DB.NewRecord(user) || !DB.NewRecord(&user) { | ||||
| 		t.Error("User should be new record before create") | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Create(&user).RowsAffected; count != 1 { | ||||
| 		t.Error("There should be one record be affected when create record") | ||||
| 	} | ||||
| 	if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil { | ||||
| 		t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ") | ||||
| 	} | ||||
| } | ||||
| @ -1,357 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| type CustomizeColumn struct { | ||||
| 	ID   int64      `gorm:"column:mapped_id; primary_key:yes"` | ||||
| 	Name string     `gorm:"column:mapped_name"` | ||||
| 	Date *time.Time `gorm:"column:mapped_time"` | ||||
| } | ||||
| 
 | ||||
| // Make sure an ignored field does not interfere with another field's custom
 | ||||
| // column name that matches the ignored field.
 | ||||
| type CustomColumnAndIgnoredFieldClash struct { | ||||
| 	Body    string `sql:"-"` | ||||
| 	RawBody string `gorm:"column:body"` | ||||
| } | ||||
| 
 | ||||
| func TestCustomizeColumn(t *testing.T) { | ||||
| 	col := "mapped_name" | ||||
| 	DB.DropTable(&CustomizeColumn{}) | ||||
| 	DB.AutoMigrate(&CustomizeColumn{}) | ||||
| 
 | ||||
| 	scope := DB.NewScope(&CustomizeColumn{}) | ||||
| 	if !scope.Dialect().HasColumn(scope.TableName(), col) { | ||||
| 		t.Errorf("CustomizeColumn should have column %s", col) | ||||
| 	} | ||||
| 
 | ||||
| 	col = "mapped_id" | ||||
| 	if scope.PrimaryKey() != col { | ||||
| 		t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) | ||||
| 	} | ||||
| 
 | ||||
| 	expected := "foo" | ||||
| 	now := time.Now() | ||||
| 	cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} | ||||
| 
 | ||||
| 	if count := DB.Create(&cc).RowsAffected; count != 1 { | ||||
| 		t.Error("There should be one record be affected when create record") | ||||
| 	} | ||||
| 
 | ||||
| 	var cc1 CustomizeColumn | ||||
| 	DB.First(&cc1, 666) | ||||
| 
 | ||||
| 	if cc1.Name != expected { | ||||
| 		t.Errorf("Failed to query CustomizeColumn") | ||||
| 	} | ||||
| 
 | ||||
| 	cc.Name = "bar" | ||||
| 	DB.Save(&cc) | ||||
| 
 | ||||
| 	var cc2 CustomizeColumn | ||||
| 	DB.First(&cc2, 666) | ||||
| 	if cc2.Name != "bar" { | ||||
| 		t.Errorf("Failed to query CustomizeColumn") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { | ||||
| 	DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) | ||||
| 	if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { | ||||
| 		t.Errorf("Should not raise error: %s", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type CustomizePerson struct { | ||||
| 	IdPerson string             `gorm:"column:idPerson;primary_key:true"` | ||||
| 	Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` | ||||
| } | ||||
| 
 | ||||
| type CustomizeAccount struct { | ||||
| 	IdAccount string `gorm:"column:idAccount;primary_key:true"` | ||||
| 	Name      string | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyWithCustomizedColumn(t *testing.T) { | ||||
| 	DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") | ||||
| 	DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) | ||||
| 
 | ||||
| 	account := CustomizeAccount{IdAccount: "account", Name: "id1"} | ||||
| 	person := CustomizePerson{ | ||||
| 		IdPerson: "person", | ||||
| 		Accounts: []CustomizeAccount{account}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&account).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&person).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var person1 CustomizePerson | ||||
| 	scope := DB.NewScope(nil) | ||||
| 	if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { | ||||
| 		t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { | ||||
| 		t.Errorf("should preload correct accounts") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type CustomizeUser struct { | ||||
| 	gorm.Model | ||||
| 	Email string `sql:"column:email_address"` | ||||
| } | ||||
| 
 | ||||
| type CustomizeInvitation struct { | ||||
| 	gorm.Model | ||||
| 	Address string         `sql:"column:invitation"` | ||||
| 	Person  *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` | ||||
| } | ||||
| 
 | ||||
| func TestOneToOneWithCustomizedColumn(t *testing.T) { | ||||
| 	DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) | ||||
| 	DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) | ||||
| 
 | ||||
| 	user := CustomizeUser{ | ||||
| 		Email: "hello@example.com", | ||||
| 	} | ||||
| 	invitation := CustomizeInvitation{ | ||||
| 		Address: "hello@example.com", | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Create(&user) | ||||
| 	DB.Create(&invitation) | ||||
| 
 | ||||
| 	var invitation2 CustomizeInvitation | ||||
| 	if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if invitation2.Person.Email != user.Email { | ||||
| 		t.Errorf("Should preload one to one relation with customize foreign keys") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type PromotionDiscount struct { | ||||
| 	gorm.Model | ||||
| 	Name     string | ||||
| 	Coupons  []*PromotionCoupon `gorm:"ForeignKey:discount_id"` | ||||
| 	Rule     *PromotionRule     `gorm:"ForeignKey:discount_id"` | ||||
| 	Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` | ||||
| } | ||||
| 
 | ||||
| type PromotionBenefit struct { | ||||
| 	gorm.Model | ||||
| 	Name        string | ||||
| 	PromotionID uint | ||||
| 	Discount    PromotionDiscount `gorm:"ForeignKey:promotion_id"` | ||||
| } | ||||
| 
 | ||||
| type PromotionCoupon struct { | ||||
| 	gorm.Model | ||||
| 	Code       string | ||||
| 	DiscountID uint | ||||
| 	Discount   PromotionDiscount | ||||
| } | ||||
| 
 | ||||
| type PromotionRule struct { | ||||
| 	gorm.Model | ||||
| 	Name       string | ||||
| 	Begin      *time.Time | ||||
| 	End        *time.Time | ||||
| 	DiscountID uint | ||||
| 	Discount   *PromotionDiscount | ||||
| } | ||||
| 
 | ||||
| func TestOneToManyWithCustomizedColumn(t *testing.T) { | ||||
| 	DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) | ||||
| 	DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) | ||||
| 
 | ||||
| 	discount := PromotionDiscount{ | ||||
| 		Name: "Happy New Year", | ||||
| 		Coupons: []*PromotionCoupon{ | ||||
| 			{Code: "newyear1"}, | ||||
| 			{Code: "newyear2"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&discount).Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var discount1 PromotionDiscount | ||||
| 	if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(discount.Coupons) != 2 { | ||||
| 		t.Errorf("should find two coupons") | ||||
| 	} | ||||
| 
 | ||||
| 	var coupon PromotionCoupon | ||||
| 	if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if coupon.Discount.Name != "Happy New Year" { | ||||
| 		t.Errorf("should preload discount from coupon") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestHasOneWithPartialCustomizedColumn(t *testing.T) { | ||||
| 	DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) | ||||
| 	DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) | ||||
| 
 | ||||
| 	var begin = time.Now() | ||||
| 	var end = time.Now().Add(24 * time.Hour) | ||||
| 	discount := PromotionDiscount{ | ||||
| 		Name: "Happy New Year 2", | ||||
| 		Rule: &PromotionRule{ | ||||
| 			Name:  "time_limited", | ||||
| 			Begin: &begin, | ||||
| 			End:   &end, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&discount).Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var discount1 PromotionDiscount | ||||
| 	if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("Should be able to preload Rule") | ||||
| 	} | ||||
| 
 | ||||
| 	var rule PromotionRule | ||||
| 	if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if rule.Discount.Name != "Happy New Year 2" { | ||||
| 		t.Errorf("should preload discount from rule") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { | ||||
| 	DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) | ||||
| 	DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) | ||||
| 
 | ||||
| 	discount := PromotionDiscount{ | ||||
| 		Name: "Happy New Year 3", | ||||
| 		Benefits: []PromotionBenefit{ | ||||
| 			{Name: "free cod"}, | ||||
| 			{Name: "free shipping"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&discount).Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var discount1 PromotionDiscount | ||||
| 	if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(discount.Benefits) != 2 { | ||||
| 		t.Errorf("should find two benefits") | ||||
| 	} | ||||
| 
 | ||||
| 	var benefit PromotionBenefit | ||||
| 	if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { | ||||
| 		t.Errorf("no error should happen but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if benefit.Discount.Name != "Happy New Year 3" { | ||||
| 		t.Errorf("should preload discount from coupon") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type SelfReferencingUser struct { | ||||
| 	gorm.Model | ||||
| 	Name    string | ||||
| 	Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` | ||||
| } | ||||
| 
 | ||||
| func TestSelfReferencingMany2ManyColumn(t *testing.T) { | ||||
| 	DB.DropTable(&SelfReferencingUser{}, "UserFriends") | ||||
| 	DB.AutoMigrate(&SelfReferencingUser{}) | ||||
| 	if !DB.HasTable("UserFriends") { | ||||
| 		t.Errorf("auto migrate error, table UserFriends should be created") | ||||
| 	} | ||||
| 
 | ||||
| 	friend1 := SelfReferencingUser{Name: "friend1_m2m"} | ||||
| 	if err := DB.Create(&friend1).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	friend2 := SelfReferencingUser{Name: "friend2_m2m"} | ||||
| 	if err := DB.Create(&friend2).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user := SelfReferencingUser{ | ||||
| 		Name:    "self_m2m", | ||||
| 		Friends: []*SelfReferencingUser{&friend1, &friend2}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&user).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&user).Association("Friends").Count() != 2 { | ||||
| 		t.Errorf("Should find created friends correctly") | ||||
| 	} | ||||
| 
 | ||||
| 	var count int | ||||
| 	if err := DB.Table("UserFriends").Count(&count).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 	if count == 0 { | ||||
| 		t.Errorf("table UserFriends should have records") | ||||
| 	} | ||||
| 
 | ||||
| 	var newUser = SelfReferencingUser{} | ||||
| 
 | ||||
| 	if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { | ||||
| 		t.Errorf("no error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(newUser.Friends) != 2 { | ||||
| 		t.Errorf("Should preload created frineds for self reference m2m") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) | ||||
| 	if DB.Model(&user).Association("Friends").Count() != 3 { | ||||
| 		t.Errorf("Should find created friends correctly") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) | ||||
| 	if DB.Model(&user).Association("Friends").Count() != 1 { | ||||
| 		t.Errorf("Should find created friends correctly") | ||||
| 	} | ||||
| 
 | ||||
| 	friend := SelfReferencingUser{} | ||||
| 	DB.Model(&newUser).Association("Friends").Find(&friend) | ||||
| 	if friend.Name != "friend4_m2m" { | ||||
| 		t.Errorf("Should find created friends correctly") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&newUser).Association("Friends").Delete(friend) | ||||
| 	if DB.Model(&user).Association("Friends").Count() != 0 { | ||||
| 		t.Errorf("All friends should be deleted") | ||||
| 	} | ||||
| } | ||||
| @ -1,91 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestDelete(t *testing.T) { | ||||
| 	user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} | ||||
| 	DB.Save(&user1) | ||||
| 	DB.Save(&user2) | ||||
| 
 | ||||
| 	if err := DB.Delete(&user1).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when delete a record, err=%s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { | ||||
| 		t.Errorf("User can't be found after delete") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { | ||||
| 		t.Errorf("Other users that not deleted should be found-able") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestInlineDelete(t *testing.T) { | ||||
| 	user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} | ||||
| 	DB.Save(&user1) | ||||
| 	DB.Save(&user2) | ||||
| 
 | ||||
| 	if DB.Delete(&User{}, user1.Id).Error != nil { | ||||
| 		t.Errorf("No error should happen when delete a record") | ||||
| 	} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { | ||||
| 		t.Errorf("User can't be found after delete") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when delete a record, err=%s", err) | ||||
| 	} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { | ||||
| 		t.Errorf("User can't be found after delete") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSoftDelete(t *testing.T) { | ||||
| 	type User struct { | ||||
| 		Id        int64 | ||||
| 		Name      string | ||||
| 		DeletedAt *time.Time | ||||
| 	} | ||||
| 	DB.AutoMigrate(&User{}) | ||||
| 
 | ||||
| 	user := User{Name: "soft_delete"} | ||||
| 	DB.Save(&user) | ||||
| 	DB.Delete(&user) | ||||
| 
 | ||||
| 	if DB.First(&User{}, "name = ?", user.Name).Error == nil { | ||||
| 		t.Errorf("Can't find a soft deleted record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||
| 		t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Unscoped().Delete(&user) | ||||
| 	if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { | ||||
| 		t.Errorf("Can't find permanently deleted record") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) { | ||||
| 	creditCard := CreditCard{Number: "411111111234567"} | ||||
| 	DB.Save(&creditCard) | ||||
| 	DB.Delete(&creditCard) | ||||
| 
 | ||||
| 	if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" { | ||||
| 		t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil { | ||||
| 		t.Errorf("Can't find a soft deleted record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil { | ||||
| 		t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Unscoped().Delete(&creditCard) | ||||
| 	if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() { | ||||
| 		t.Errorf("Can't find permanently deleted record") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										147
									
								
								dialect.go
									
									
									
									
									
								
							
							
						
						
									
										147
									
								
								dialect.go
									
									
									
									
									
								
							| @ -1,147 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // Dialect interface contains behaviors that differ across SQL database
 | ||||
| type Dialect interface { | ||||
| 	// GetName get dialect's name
 | ||||
| 	GetName() string | ||||
| 
 | ||||
| 	// SetDB set db for dialect
 | ||||
| 	SetDB(db SQLCommon) | ||||
| 
 | ||||
| 	// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
 | ||||
| 	BindVar(i int) string | ||||
| 	// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
 | ||||
| 	Quote(key string) string | ||||
| 	// DataTypeOf return data's sql type
 | ||||
| 	DataTypeOf(field *StructField) string | ||||
| 
 | ||||
| 	// HasIndex check has index or not
 | ||||
| 	HasIndex(tableName string, indexName string) bool | ||||
| 	// HasForeignKey check has foreign key or not
 | ||||
| 	HasForeignKey(tableName string, foreignKeyName string) bool | ||||
| 	// RemoveIndex remove index
 | ||||
| 	RemoveIndex(tableName string, indexName string) error | ||||
| 	// HasTable check has table or not
 | ||||
| 	HasTable(tableName string) bool | ||||
| 	// HasColumn check has column or not
 | ||||
| 	HasColumn(tableName string, columnName string) bool | ||||
| 	// ModifyColumn modify column's type
 | ||||
| 	ModifyColumn(tableName string, columnName string, typ string) error | ||||
| 
 | ||||
| 	// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
 | ||||
| 	LimitAndOffsetSQL(limit, offset interface{}) (string, error) | ||||
| 	// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
 | ||||
| 	SelectFromDummyTable() string | ||||
| 	// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
 | ||||
| 	LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string | ||||
| 	// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
 | ||||
| 	LastInsertIDReturningSuffix(tableName, columnName string) string | ||||
| 	// DefaultValueStr
 | ||||
| 	DefaultValueStr() string | ||||
| 
 | ||||
| 	// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
 | ||||
| 	BuildKeyName(kind, tableName string, fields ...string) string | ||||
| 
 | ||||
| 	// NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
 | ||||
| 	NormalizeIndexAndColumn(indexName, columnName string) (string, string) | ||||
| 
 | ||||
| 	// CurrentDatabase return current database name
 | ||||
| 	CurrentDatabase() string | ||||
| } | ||||
| 
 | ||||
| var dialectsMap = map[string]Dialect{} | ||||
| 
 | ||||
| func newDialect(name string, db SQLCommon) Dialect { | ||||
| 	if value, ok := dialectsMap[name]; ok { | ||||
| 		dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) | ||||
| 		dialect.SetDB(db) | ||||
| 		return dialect | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) | ||||
| 	commontDialect := &commonDialect{} | ||||
| 	commontDialect.SetDB(db) | ||||
| 	return commontDialect | ||||
| } | ||||
| 
 | ||||
| // RegisterDialect register new dialect
 | ||||
| func RegisterDialect(name string, dialect Dialect) { | ||||
| 	dialectsMap[name] = dialect | ||||
| } | ||||
| 
 | ||||
| // GetDialect gets the dialect for the specified dialect name
 | ||||
| func GetDialect(name string) (dialect Dialect, ok bool) { | ||||
| 	dialect, ok = dialectsMap[name] | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // ParseFieldStructForDialect get field's sql data type
 | ||||
| var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { | ||||
| 	// Get redirected field type
 | ||||
| 	var ( | ||||
| 		reflectType = field.Struct.Type | ||||
| 		dataType, _ = field.TagSettingsGet("TYPE") | ||||
| 	) | ||||
| 
 | ||||
| 	for reflectType.Kind() == reflect.Ptr { | ||||
| 		reflectType = reflectType.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	// Get redirected field value
 | ||||
| 	fieldValue = reflect.Indirect(reflect.New(reflectType)) | ||||
| 
 | ||||
| 	if gormDataType, ok := fieldValue.Interface().(interface { | ||||
| 		GormDataType(Dialect) string | ||||
| 	}); ok { | ||||
| 		dataType = gormDataType.GormDataType(dialect) | ||||
| 	} | ||||
| 
 | ||||
| 	// Get scanner's real value
 | ||||
| 	if dataType == "" { | ||||
| 		var getScannerValue func(reflect.Value) | ||||
| 		getScannerValue = func(value reflect.Value) { | ||||
| 			fieldValue = value | ||||
| 			if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { | ||||
| 				getScannerValue(fieldValue.Field(0)) | ||||
| 			} | ||||
| 		} | ||||
| 		getScannerValue(fieldValue) | ||||
| 	} | ||||
| 
 | ||||
| 	// Default Size
 | ||||
| 	if num, ok := field.TagSettingsGet("SIZE"); ok { | ||||
| 		size, _ = strconv.Atoi(num) | ||||
| 	} else { | ||||
| 		size = 255 | ||||
| 	} | ||||
| 
 | ||||
| 	// Default type from tag setting
 | ||||
| 	notNull, _ := field.TagSettingsGet("NOT NULL") | ||||
| 	unique, _ := field.TagSettingsGet("UNIQUE") | ||||
| 	additionalType = notNull + " " + unique | ||||
| 	if value, ok := field.TagSettingsGet("DEFAULT"); ok { | ||||
| 		additionalType = additionalType + " DEFAULT " + value | ||||
| 	} | ||||
| 
 | ||||
| 	if value, ok := field.TagSettingsGet("COMMENT"); ok { | ||||
| 		additionalType = additionalType + " COMMENT " + value | ||||
| 	} | ||||
| 
 | ||||
| 	return fieldValue, dataType, size, strings.TrimSpace(additionalType) | ||||
| } | ||||
| 
 | ||||
| func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { | ||||
| 	if strings.Contains(tableName, ".") { | ||||
| 		splitStrings := strings.SplitN(tableName, ".", 2) | ||||
| 		return splitStrings[0], splitStrings[1] | ||||
| 	} | ||||
| 	return dialect.CurrentDatabase(), tableName | ||||
| } | ||||
| @ -1,196 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") | ||||
| 
 | ||||
| // DefaultForeignKeyNamer contains the default foreign key name generator method
 | ||||
| type DefaultForeignKeyNamer struct { | ||||
| } | ||||
| 
 | ||||
| type commonDialect struct { | ||||
| 	db SQLCommon | ||||
| 	DefaultForeignKeyNamer | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	RegisterDialect("common", &commonDialect{}) | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) GetName() string { | ||||
| 	return "common" | ||||
| } | ||||
| 
 | ||||
| func (s *commonDialect) SetDB(db SQLCommon) { | ||||
| 	s.db = db | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) BindVar(i int) string { | ||||
| 	return "$$$" // ?
 | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) Quote(key string) string { | ||||
| 	return fmt.Sprintf(`"%s"`, key) | ||||
| } | ||||
| 
 | ||||
| func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { | ||||
| 	if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { | ||||
| 		return strings.ToLower(value) != "false" | ||||
| 	} | ||||
| 	return field.IsPrimaryKey | ||||
| } | ||||
| 
 | ||||
| func (s *commonDialect) DataTypeOf(field *StructField) string { | ||||
| 	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		switch dataValue.Kind() { | ||||
| 		case reflect.Bool: | ||||
| 			sqlType = "BOOLEAN" | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				sqlType = "INTEGER AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "INTEGER" | ||||
| 			} | ||||
| 		case reflect.Int64, reflect.Uint64: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				sqlType = "BIGINT AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "BIGINT" | ||||
| 			} | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			sqlType = "FLOAT" | ||||
| 		case reflect.String: | ||||
| 			if size > 0 && size < 65532 { | ||||
| 				sqlType = fmt.Sprintf("VARCHAR(%d)", size) | ||||
| 			} else { | ||||
| 				sqlType = "VARCHAR(65532)" | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			if _, ok := dataValue.Interface().(time.Time); ok { | ||||
| 				sqlType = "TIMESTAMP" | ||||
| 			} | ||||
| 		default: | ||||
| 			if _, ok := dataValue.Interface().([]byte); ok { | ||||
| 				if size > 0 && size < 65532 { | ||||
| 					sqlType = fmt.Sprintf("BINARY(%d)", size) | ||||
| 				} else { | ||||
| 					sqlType = "BINARY(65532)" | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) | ||||
| 	} | ||||
| 
 | ||||
| 	if strings.TrimSpace(additionalType) == "" { | ||||
| 		return sqlType | ||||
| 	} | ||||
| 	return fmt.Sprintf("%v %v", sqlType, additionalType) | ||||
| } | ||||
| 
 | ||||
| func (s commonDialect) HasIndex(tableName string, indexName string) bool { | ||||
| 	var count int | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s commonDialect) RemoveIndex(tableName string, indexName string) error { | ||||
| 	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (s commonDialect) HasTable(tableName string) bool { | ||||
| 	var count int | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s commonDialect) HasColumn(tableName string, columnName string) bool { | ||||
| 	var count int | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { | ||||
| 	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s commonDialect) CurrentDatabase() (name string) { | ||||
| 	s.db.QueryRow("SELECT DATABASE()").Scan(&name) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // LimitAndOffsetSQL return generated SQL with Limit and Offset
 | ||||
| func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { | ||||
| 	if limit != nil { | ||||
| 		if parsedLimit, err := s.parseInt(limit); err != nil { | ||||
| 			return "", err | ||||
| 		} else if parsedLimit >= 0 { | ||||
| 			sql += fmt.Sprintf(" LIMIT %d", parsedLimit) | ||||
| 		} | ||||
| 	} | ||||
| 	if offset != nil { | ||||
| 		if parsedOffset, err := s.parseInt(offset); err != nil { | ||||
| 			return "", err | ||||
| 		} else if parsedOffset >= 0 { | ||||
| 			sql += fmt.Sprintf(" OFFSET %d", parsedOffset) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) SelectFromDummyTable() string { | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) DefaultValueStr() string { | ||||
| 	return "DEFAULT VALUES" | ||||
| } | ||||
| 
 | ||||
| // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
 | ||||
| func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { | ||||
| 	keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) | ||||
| 	keyName = keyNameRegex.ReplaceAllString(keyName, "_") | ||||
| 	return keyName | ||||
| } | ||||
| 
 | ||||
| // NormalizeIndexAndColumn returns argument's index name and column name without doing anything
 | ||||
| func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { | ||||
| 	return indexName, columnName | ||||
| } | ||||
| 
 | ||||
| func (commonDialect) parseInt(value interface{}) (int64, error) { | ||||
| 	return strconv.ParseInt(fmt.Sprint(value), 0, 0) | ||||
| } | ||||
| 
 | ||||
| // IsByteArrayOrSlice returns true of the reflected value is an array or slice
 | ||||
| func IsByteArrayOrSlice(value reflect.Value) bool { | ||||
| 	return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) | ||||
| } | ||||
							
								
								
									
										246
									
								
								dialect_mysql.go
									
									
									
									
									
								
							
							
						
						
									
										246
									
								
								dialect_mysql.go
									
									
									
									
									
								
							| @ -1,246 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/sha1" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
| 
 | ||||
| var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) | ||||
| 
 | ||||
| type mysql struct { | ||||
| 	commonDialect | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	RegisterDialect("mysql", &mysql{}) | ||||
| } | ||||
| 
 | ||||
| func (mysql) GetName() string { | ||||
| 	return "mysql" | ||||
| } | ||||
| 
 | ||||
| func (mysql) Quote(key string) string { | ||||
| 	return fmt.Sprintf("`%s`", key) | ||||
| } | ||||
| 
 | ||||
| // Get Data Type for MySQL Dialect
 | ||||
| func (s *mysql) DataTypeOf(field *StructField) string { | ||||
| 	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | ||||
| 
 | ||||
| 	// MySQL allows only one auto increment column per table, and it must
 | ||||
| 	// be a KEY column.
 | ||||
| 	if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { | ||||
| 		if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { | ||||
| 			field.TagSettingsDelete("AUTO_INCREMENT") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		switch dataValue.Kind() { | ||||
| 		case reflect.Bool: | ||||
| 			sqlType = "boolean" | ||||
| 		case reflect.Int8: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "tinyint AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "tinyint" | ||||
| 			} | ||||
| 		case reflect.Int, reflect.Int16, reflect.Int32: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "int AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "int" | ||||
| 			} | ||||
| 		case reflect.Uint8: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "tinyint unsigned AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "tinyint unsigned" | ||||
| 			} | ||||
| 		case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "int unsigned AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "int unsigned" | ||||
| 			} | ||||
| 		case reflect.Int64: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "bigint AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "bigint" | ||||
| 			} | ||||
| 		case reflect.Uint64: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "bigint unsigned AUTO_INCREMENT" | ||||
| 			} else { | ||||
| 				sqlType = "bigint unsigned" | ||||
| 			} | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			sqlType = "double" | ||||
| 		case reflect.String: | ||||
| 			if size > 0 && size < 65532 { | ||||
| 				sqlType = fmt.Sprintf("varchar(%d)", size) | ||||
| 			} else { | ||||
| 				sqlType = "longtext" | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			if _, ok := dataValue.Interface().(time.Time); ok { | ||||
| 				precision := "" | ||||
| 				if p, ok := field.TagSettingsGet("PRECISION"); ok { | ||||
| 					precision = fmt.Sprintf("(%s)", p) | ||||
| 				} | ||||
| 
 | ||||
| 				if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { | ||||
| 					sqlType = fmt.Sprintf("DATETIME%v", precision) | ||||
| 				} else { | ||||
| 					sqlType = fmt.Sprintf("DATETIME%v NULL", precision) | ||||
| 				} | ||||
| 			} | ||||
| 		default: | ||||
| 			if IsByteArrayOrSlice(dataValue) { | ||||
| 				if size > 0 && size < 65532 { | ||||
| 					sqlType = fmt.Sprintf("varbinary(%d)", size) | ||||
| 				} else { | ||||
| 					sqlType = "longblob" | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) | ||||
| 	} | ||||
| 
 | ||||
| 	if strings.TrimSpace(additionalType) == "" { | ||||
| 		return sqlType | ||||
| 	} | ||||
| 	return fmt.Sprintf("%v %v", sqlType, additionalType) | ||||
| } | ||||
| 
 | ||||
| func (s mysql) RemoveIndex(tableName string, indexName string) error { | ||||
| 	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { | ||||
| 	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { | ||||
| 	if limit != nil { | ||||
| 		parsedLimit, err := s.parseInt(limit) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		if parsedLimit >= 0 { | ||||
| 			sql += fmt.Sprintf(" LIMIT %d", parsedLimit) | ||||
| 
 | ||||
| 			if offset != nil { | ||||
| 				parsedOffset, err := s.parseInt(offset) | ||||
| 				if err != nil { | ||||
| 					return "", err | ||||
| 				} | ||||
| 				if parsedOffset >= 0 { | ||||
| 					sql += fmt.Sprintf(" OFFSET %d", parsedOffset) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { | ||||
| 	var count int | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s mysql) HasTable(tableName string) bool { | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	var name string | ||||
| 	// allow mysql database name with '-' character
 | ||||
| 	if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return false | ||||
| 		} | ||||
| 		panic(err) | ||||
| 	} else { | ||||
| 		return true | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s mysql) HasIndex(tableName string, indexName string) bool { | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { | ||||
| 		panic(err) | ||||
| 	} else { | ||||
| 		defer rows.Close() | ||||
| 		return rows.Next() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s mysql) HasColumn(tableName string, columnName string) bool { | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { | ||||
| 		panic(err) | ||||
| 	} else { | ||||
| 		defer rows.Close() | ||||
| 		return rows.Next() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s mysql) CurrentDatabase() (name string) { | ||||
| 	s.db.QueryRow("SELECT DATABASE()").Scan(&name) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (mysql) SelectFromDummyTable() string { | ||||
| 	return "FROM DUAL" | ||||
| } | ||||
| 
 | ||||
| func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { | ||||
| 	keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) | ||||
| 	if utf8.RuneCountInString(keyName) <= 64 { | ||||
| 		return keyName | ||||
| 	} | ||||
| 	h := sha1.New() | ||||
| 	h.Write([]byte(keyName)) | ||||
| 	bs := h.Sum(nil) | ||||
| 
 | ||||
| 	// sha1 is 40 characters, keep first 24 characters of destination
 | ||||
| 	destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) | ||||
| 	if len(destRunes) > 24 { | ||||
| 		destRunes = destRunes[:24] | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Sprintf("%s%x", string(destRunes), bs) | ||||
| } | ||||
| 
 | ||||
| // NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed
 | ||||
| func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { | ||||
| 	submatch := mysqlIndexRegex.FindStringSubmatch(indexName) | ||||
| 	if len(submatch) != 3 { | ||||
| 		return indexName, columnName | ||||
| 	} | ||||
| 	indexName = submatch[1] | ||||
| 	columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) | ||||
| 	return indexName, columnName | ||||
| } | ||||
| 
 | ||||
| func (mysql) DefaultValueStr() string { | ||||
| 	return "VALUES()" | ||||
| } | ||||
| @ -1,147 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type postgres struct { | ||||
| 	commonDialect | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	RegisterDialect("postgres", &postgres{}) | ||||
| 	RegisterDialect("cloudsqlpostgres", &postgres{}) | ||||
| } | ||||
| 
 | ||||
| func (postgres) GetName() string { | ||||
| 	return "postgres" | ||||
| } | ||||
| 
 | ||||
| func (postgres) BindVar(i int) string { | ||||
| 	return fmt.Sprintf("$%v", i) | ||||
| } | ||||
| 
 | ||||
| func (s *postgres) DataTypeOf(field *StructField) string { | ||||
| 	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		switch dataValue.Kind() { | ||||
| 		case reflect.Bool: | ||||
| 			sqlType = "boolean" | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "serial" | ||||
| 			} else { | ||||
| 				sqlType = "integer" | ||||
| 			} | ||||
| 		case reflect.Int64, reflect.Uint32, reflect.Uint64: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "bigserial" | ||||
| 			} else { | ||||
| 				sqlType = "bigint" | ||||
| 			} | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			sqlType = "numeric" | ||||
| 		case reflect.String: | ||||
| 			if _, ok := field.TagSettingsGet("SIZE"); !ok { | ||||
| 				size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
 | ||||
| 			} | ||||
| 
 | ||||
| 			if size > 0 && size < 65532 { | ||||
| 				sqlType = fmt.Sprintf("varchar(%d)", size) | ||||
| 			} else { | ||||
| 				sqlType = "text" | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			if _, ok := dataValue.Interface().(time.Time); ok { | ||||
| 				sqlType = "timestamp with time zone" | ||||
| 			} | ||||
| 		case reflect.Map: | ||||
| 			if dataValue.Type().Name() == "Hstore" { | ||||
| 				sqlType = "hstore" | ||||
| 			} | ||||
| 		default: | ||||
| 			if IsByteArrayOrSlice(dataValue) { | ||||
| 				sqlType = "bytea" | ||||
| 
 | ||||
| 				if isUUID(dataValue) { | ||||
| 					sqlType = "uuid" | ||||
| 				} | ||||
| 
 | ||||
| 				if isJSON(dataValue) { | ||||
| 					sqlType = "jsonb" | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) | ||||
| 	} | ||||
| 
 | ||||
| 	if strings.TrimSpace(additionalType) == "" { | ||||
| 		return sqlType | ||||
| 	} | ||||
| 	return fmt.Sprintf("%v %v", sqlType, additionalType) | ||||
| } | ||||
| 
 | ||||
| func (s postgres) HasIndex(tableName string, indexName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s postgres) HasTable(tableName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s postgres) HasColumn(tableName string, columnName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s postgres) CurrentDatabase() (name string) { | ||||
| 	s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { | ||||
| 	return fmt.Sprintf("RETURNING %v.%v", tableName, key) | ||||
| } | ||||
| 
 | ||||
| func (postgres) SupportLastInsertID() bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func isUUID(value reflect.Value) bool { | ||||
| 	if value.Kind() != reflect.Array || value.Type().Len() != 16 { | ||||
| 		return false | ||||
| 	} | ||||
| 	typename := value.Type().Name() | ||||
| 	lower := strings.ToLower(typename) | ||||
| 	return "uuid" == lower || "guid" == lower | ||||
| } | ||||
| 
 | ||||
| func isJSON(value reflect.Value) bool { | ||||
| 	_, ok := value.Interface().(json.RawMessage) | ||||
| 	return ok | ||||
| } | ||||
| @ -1,107 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type sqlite3 struct { | ||||
| 	commonDialect | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	RegisterDialect("sqlite3", &sqlite3{}) | ||||
| } | ||||
| 
 | ||||
| func (sqlite3) GetName() string { | ||||
| 	return "sqlite3" | ||||
| } | ||||
| 
 | ||||
| // Get Data Type for Sqlite Dialect
 | ||||
| func (s *sqlite3) DataTypeOf(field *StructField) string { | ||||
| 	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		switch dataValue.Kind() { | ||||
| 		case reflect.Bool: | ||||
| 			sqlType = "bool" | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "integer primary key autoincrement" | ||||
| 			} else { | ||||
| 				sqlType = "integer" | ||||
| 			} | ||||
| 		case reflect.Int64, reflect.Uint64: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "integer primary key autoincrement" | ||||
| 			} else { | ||||
| 				sqlType = "bigint" | ||||
| 			} | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			sqlType = "real" | ||||
| 		case reflect.String: | ||||
| 			if size > 0 && size < 65532 { | ||||
| 				sqlType = fmt.Sprintf("varchar(%d)", size) | ||||
| 			} else { | ||||
| 				sqlType = "text" | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			if _, ok := dataValue.Interface().(time.Time); ok { | ||||
| 				sqlType = "datetime" | ||||
| 			} | ||||
| 		default: | ||||
| 			if IsByteArrayOrSlice(dataValue) { | ||||
| 				sqlType = "blob" | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) | ||||
| 	} | ||||
| 
 | ||||
| 	if strings.TrimSpace(additionalType) == "" { | ||||
| 		return sqlType | ||||
| 	} | ||||
| 	return fmt.Sprintf("%v %v", sqlType, additionalType) | ||||
| } | ||||
| 
 | ||||
| func (s sqlite3) HasIndex(tableName string, indexName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s sqlite3) HasTable(tableName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s sqlite3) HasColumn(tableName string, columnName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s sqlite3) CurrentDatabase() (name string) { | ||||
| 	var ( | ||||
| 		ifaces   = make([]interface{}, 3) | ||||
| 		pointers = make([]*string, 3) | ||||
| 		i        int | ||||
| 	) | ||||
| 	for i = 0; i < 3; i++ { | ||||
| 		ifaces[i] = &pointers[i] | ||||
| 	} | ||||
| 	if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if pointers[1] != nil { | ||||
| 		name = *pointers[1] | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| @ -1,253 +0,0 @@ | ||||
| package mssql | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	// Importing mssql driver package only in dialect file, otherwide not needed
 | ||||
| 	_ "github.com/denisenkom/go-mssqldb" | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func setIdentityInsert(scope *gorm.Scope) { | ||||
| 	if scope.Dialect().GetName() == "mssql" { | ||||
| 		for _, field := range scope.PrimaryFields() { | ||||
| 			if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { | ||||
| 				scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) | ||||
| 				scope.InstanceSet("mssql:identity_insert_on", true) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func turnOffIdentityInsert(scope *gorm.Scope) { | ||||
| 	if scope.Dialect().GetName() == "mssql" { | ||||
| 		if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { | ||||
| 			scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) | ||||
| 	gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) | ||||
| 	gorm.RegisterDialect("mssql", &mssql{}) | ||||
| } | ||||
| 
 | ||||
| type mssql struct { | ||||
| 	db gorm.SQLCommon | ||||
| 	gorm.DefaultForeignKeyNamer | ||||
| } | ||||
| 
 | ||||
| func (mssql) GetName() string { | ||||
| 	return "mssql" | ||||
| } | ||||
| 
 | ||||
| func (s *mssql) SetDB(db gorm.SQLCommon) { | ||||
| 	s.db = db | ||||
| } | ||||
| 
 | ||||
| func (mssql) BindVar(i int) string { | ||||
| 	return "$$$" // ?
 | ||||
| } | ||||
| 
 | ||||
| func (mssql) Quote(key string) string { | ||||
| 	return fmt.Sprintf(`[%s]`, key) | ||||
| } | ||||
| 
 | ||||
| func (s *mssql) DataTypeOf(field *gorm.StructField) string { | ||||
| 	var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		switch dataValue.Kind() { | ||||
| 		case reflect.Bool: | ||||
| 			sqlType = "bit" | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "int IDENTITY(1,1)" | ||||
| 			} else { | ||||
| 				sqlType = "int" | ||||
| 			} | ||||
| 		case reflect.Int64, reflect.Uint64: | ||||
| 			if s.fieldCanAutoIncrement(field) { | ||||
| 				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") | ||||
| 				sqlType = "bigint IDENTITY(1,1)" | ||||
| 			} else { | ||||
| 				sqlType = "bigint" | ||||
| 			} | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			sqlType = "float" | ||||
| 		case reflect.String: | ||||
| 			if size > 0 && size < 8000 { | ||||
| 				sqlType = fmt.Sprintf("nvarchar(%d)", size) | ||||
| 			} else { | ||||
| 				sqlType = "nvarchar(max)" | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			if _, ok := dataValue.Interface().(time.Time); ok { | ||||
| 				sqlType = "datetimeoffset" | ||||
| 			} | ||||
| 		default: | ||||
| 			if gorm.IsByteArrayOrSlice(dataValue) { | ||||
| 				if size > 0 && size < 8000 { | ||||
| 					sqlType = fmt.Sprintf("varbinary(%d)", size) | ||||
| 				} else { | ||||
| 					sqlType = "varbinary(max)" | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlType == "" { | ||||
| 		panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) | ||||
| 	} | ||||
| 
 | ||||
| 	if strings.TrimSpace(additionalType) == "" { | ||||
| 		return sqlType | ||||
| 	} | ||||
| 	return fmt.Sprintf("%v %v", sqlType, additionalType) | ||||
| } | ||||
| 
 | ||||
| func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { | ||||
| 	if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { | ||||
| 		return value != "FALSE" | ||||
| 	} | ||||
| 	return field.IsPrimaryKey | ||||
| } | ||||
| 
 | ||||
| func (s mssql) HasIndex(tableName string, indexName string) bool { | ||||
| 	var count int | ||||
| 	s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s mssql) RemoveIndex(tableName string, indexName string) error { | ||||
| 	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { | ||||
| 	var count int | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	s.db.QueryRow(`SELECT count(*)  | ||||
| 	FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id  | ||||
| 		inner join information_schema.tables as I on I.TABLE_NAME = T.name  | ||||
| 	WHERE F.name = ?  | ||||
| 		AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s mssql) HasTable(tableName string) bool { | ||||
| 	var count int | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s mssql) HasColumn(tableName string, columnName string) bool { | ||||
| 	var count int | ||||
| 	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) | ||||
| 	s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { | ||||
| 	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s mssql) CurrentDatabase() (name string) { | ||||
| 	s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func parseInt(value interface{}) (int64, error) { | ||||
| 	return strconv.ParseInt(fmt.Sprint(value), 0, 0) | ||||
| } | ||||
| 
 | ||||
| func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { | ||||
| 	if offset != nil { | ||||
| 		if parsedOffset, err := parseInt(offset); err != nil { | ||||
| 			return "", err | ||||
| 		} else if parsedOffset >= 0 { | ||||
| 			sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) | ||||
| 		} | ||||
| 	} | ||||
| 	if limit != nil { | ||||
| 		if parsedLimit, err := parseInt(limit); err != nil { | ||||
| 			return "", err | ||||
| 		} else if parsedLimit >= 0 { | ||||
| 			if sql == "" { | ||||
| 				// add default zero offset
 | ||||
| 				sql += " OFFSET 0 ROWS" | ||||
| 			} | ||||
| 			sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (mssql) SelectFromDummyTable() string { | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { | ||||
| 	if len(columns) == 0 { | ||||
| 		// No OUTPUT to query
 | ||||
| 		return "" | ||||
| 	} | ||||
| 	return fmt.Sprintf("OUTPUT Inserted.%v", columnName) | ||||
| } | ||||
| 
 | ||||
| func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { | ||||
| 	// https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
 | ||||
| 	return "; SELECT SCOPE_IDENTITY()" | ||||
| } | ||||
| 
 | ||||
| func (mssql) DefaultValueStr() string { | ||||
| 	return "DEFAULT VALUES" | ||||
| } | ||||
| 
 | ||||
| // NormalizeIndexAndColumn returns argument's index name and column name without doing anything
 | ||||
| func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { | ||||
| 	return indexName, columnName | ||||
| } | ||||
| 
 | ||||
| func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { | ||||
| 	if strings.Contains(tableName, ".") { | ||||
| 		splitStrings := strings.SplitN(tableName, ".", 2) | ||||
| 		return splitStrings[0], splitStrings[1] | ||||
| 	} | ||||
| 	return dialect.CurrentDatabase(), tableName | ||||
| } | ||||
| 
 | ||||
| // JSON type to support easy handling of JSON data in character table fields
 | ||||
| // using golang json.RawMessage for deferred decoding/encoding
 | ||||
| type JSON struct { | ||||
| 	json.RawMessage | ||||
| } | ||||
| 
 | ||||
| // Value get value of JSON
 | ||||
| func (j JSON) Value() (driver.Value, error) { | ||||
| 	if len(j.RawMessage) == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	return j.MarshalJSON() | ||||
| } | ||||
| 
 | ||||
| // Scan scan value into JSON
 | ||||
| func (j *JSON) Scan(value interface{}) error { | ||||
| 	str, ok := value.(string) | ||||
| 	if !ok { | ||||
| 		return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) | ||||
| 	} | ||||
| 	bytes := []byte(str) | ||||
| 	return json.Unmarshal(bytes, j) | ||||
| } | ||||
| @ -1,3 +0,0 @@ | ||||
| package mysql | ||||
| 
 | ||||
| import _ "github.com/go-sql-driver/mysql" | ||||
| @ -1,81 +0,0 @@ | ||||
| package postgres | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 
 | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	_ "github.com/lib/pq" | ||||
| 	"github.com/lib/pq/hstore" | ||||
| ) | ||||
| 
 | ||||
| type Hstore map[string]*string | ||||
| 
 | ||||
| // Value get value of Hstore
 | ||||
| func (h Hstore) Value() (driver.Value, error) { | ||||
| 	hstore := hstore.Hstore{Map: map[string]sql.NullString{}} | ||||
| 	if len(h) == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	for key, value := range h { | ||||
| 		var s sql.NullString | ||||
| 		if value != nil { | ||||
| 			s.String = *value | ||||
| 			s.Valid = true | ||||
| 		} | ||||
| 		hstore.Map[key] = s | ||||
| 	} | ||||
| 	return hstore.Value() | ||||
| } | ||||
| 
 | ||||
| // Scan scan value into Hstore
 | ||||
| func (h *Hstore) Scan(value interface{}) error { | ||||
| 	hstore := hstore.Hstore{} | ||||
| 
 | ||||
| 	if err := hstore.Scan(value); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if len(hstore.Map) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	*h = Hstore{} | ||||
| 	for k := range hstore.Map { | ||||
| 		if hstore.Map[k].Valid { | ||||
| 			s := hstore.Map[k].String | ||||
| 			(*h)[k] = &s | ||||
| 		} else { | ||||
| 			(*h)[k] = nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Jsonb Postgresql's JSONB data type
 | ||||
| type Jsonb struct { | ||||
| 	json.RawMessage | ||||
| } | ||||
| 
 | ||||
| // Value get value of Jsonb
 | ||||
| func (j Jsonb) Value() (driver.Value, error) { | ||||
| 	if len(j.RawMessage) == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	return j.MarshalJSON() | ||||
| } | ||||
| 
 | ||||
| // Scan scan value into Jsonb
 | ||||
| func (j *Jsonb) Scan(value interface{}) error { | ||||
| 	bytes, ok := value.([]byte) | ||||
| 	if !ok { | ||||
| 		return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) | ||||
| 	} | ||||
| 
 | ||||
| 	return json.Unmarshal(bytes, j) | ||||
| } | ||||
| @ -1,3 +0,0 @@ | ||||
| package sqlite | ||||
| 
 | ||||
| import _ "github.com/mattn/go-sqlite3" | ||||
| @ -1,30 +0,0 @@ | ||||
| version: '3' | ||||
| 
 | ||||
| services: | ||||
|   mysql: | ||||
|     image: 'mysql:latest' | ||||
|     ports: | ||||
|       - 9910:3306 | ||||
|     environment: | ||||
|       - MYSQL_DATABASE=gorm | ||||
|       - MYSQL_USER=gorm | ||||
|       - MYSQL_PASSWORD=gorm | ||||
|       - MYSQL_RANDOM_ROOT_PASSWORD="yes" | ||||
|   postgres: | ||||
|     image: 'postgres:latest' | ||||
|     ports: | ||||
|       - 9920:5432 | ||||
|     environment: | ||||
|       - POSTGRES_USER=gorm | ||||
|       - POSTGRES_DB=gorm | ||||
|       - POSTGRES_PASSWORD=gorm | ||||
|   mssql: | ||||
|     image: 'mcmoe/mssqldocker:latest' | ||||
|     ports: | ||||
|       - 9930:1433 | ||||
|     environment: | ||||
|       - ACCEPT_EULA=Y | ||||
|       - SA_PASSWORD=LoremIpsum86 | ||||
|       - MSSQL_DB=gorm | ||||
|       - MSSQL_USER=gorm | ||||
|       - MSSQL_PASSWORD=LoremIpsum86 | ||||
| @ -1,91 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import "testing" | ||||
| 
 | ||||
| type BasePost struct { | ||||
| 	Id    int64 | ||||
| 	Title string | ||||
| 	URL   string | ||||
| } | ||||
| 
 | ||||
| type Author struct { | ||||
| 	ID    string | ||||
| 	Name  string | ||||
| 	Email string | ||||
| } | ||||
| 
 | ||||
| type HNPost struct { | ||||
| 	BasePost | ||||
| 	Author  `gorm:"embedded_prefix:user_"` // Embedded struct
 | ||||
| 	Upvotes int32 | ||||
| } | ||||
| 
 | ||||
| type EngadgetPost struct { | ||||
| 	BasePost BasePost `gorm:"embedded"` | ||||
| 	Author   Author   `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
 | ||||
| 	ImageUrl string | ||||
| } | ||||
| 
 | ||||
| func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { | ||||
| 	dialect := DB.NewScope(&EngadgetPost{}).Dialect() | ||||
| 	engadgetPostScope := DB.NewScope(&EngadgetPost{}) | ||||
| 	if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { | ||||
| 		t.Errorf("should has prefix for embedded columns") | ||||
| 	} | ||||
| 
 | ||||
| 	if len(engadgetPostScope.PrimaryFields()) != 1 { | ||||
| 		t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) | ||||
| 	} | ||||
| 
 | ||||
| 	hnScope := DB.NewScope(&HNPost{}) | ||||
| 	if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { | ||||
| 		t.Errorf("should has prefix for embedded columns") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSaveAndQueryEmbeddedStruct(t *testing.T) { | ||||
| 	DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) | ||||
| 	DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) | ||||
| 	var news HNPost | ||||
| 	if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { | ||||
| 		t.Errorf("no error should happen when query with embedded struct, but got %v", err) | ||||
| 	} else if news.Title != "hn_news" { | ||||
| 		t.Errorf("embedded struct's value should be scanned correctly") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) | ||||
| 	var egNews EngadgetPost | ||||
| 	if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { | ||||
| 		t.Errorf("no error should happen when query with embedded struct, but got %v", err) | ||||
| 	} else if egNews.BasePost.Title != "engadget_news" { | ||||
| 		t.Errorf("embedded struct's value should be scanned correctly") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.NewScope(&HNPost{}).PrimaryField() == nil { | ||||
| 		t.Errorf("primary key with embedded struct should works") | ||||
| 	} | ||||
| 
 | ||||
| 	for _, field := range DB.NewScope(&HNPost{}).Fields() { | ||||
| 		if field.Name == "BasePost" { | ||||
| 			t.Errorf("scope Fields should not contain embedded struct") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEmbeddedPointerTypeStruct(t *testing.T) { | ||||
| 	type HNPost struct { | ||||
| 		*BasePost | ||||
| 		Upvotes int32 | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) | ||||
| 
 | ||||
| 	var hnPost HNPost | ||||
| 	if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { | ||||
| 		t.Errorf("No error should happen when find embedded pointer type, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if hnPost.Title != "embedded_pointer_type" { | ||||
| 		t.Errorf("Should find correct value for embedded pointer type") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										72
									
								
								errors.go
									
									
									
									
									
								
							
							
						
						
									
										72
									
								
								errors.go
									
									
									
									
									
								
							| @ -1,72 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	// ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
 | ||||
| 	ErrRecordNotFound = errors.New("record not found") | ||||
| 	// ErrInvalidSQL occurs when you attempt a query with invalid SQL
 | ||||
| 	ErrInvalidSQL = errors.New("invalid SQL") | ||||
| 	// ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback`
 | ||||
| 	ErrInvalidTransaction = errors.New("no valid transaction") | ||||
| 	// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
 | ||||
| 	ErrCantStartTransaction = errors.New("can't start transaction") | ||||
| 	// ErrUnaddressable unaddressable value
 | ||||
| 	ErrUnaddressable = errors.New("using unaddressable value") | ||||
| ) | ||||
| 
 | ||||
| // Errors contains all happened errors
 | ||||
| type Errors []error | ||||
| 
 | ||||
| // IsRecordNotFoundError returns true if error contains a RecordNotFound error
 | ||||
| func IsRecordNotFoundError(err error) bool { | ||||
| 	if errs, ok := err.(Errors); ok { | ||||
| 		for _, err := range errs { | ||||
| 			if err == ErrRecordNotFound { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return err == ErrRecordNotFound | ||||
| } | ||||
| 
 | ||||
| // GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
 | ||||
| func (errs Errors) GetErrors() []error { | ||||
| 	return errs | ||||
| } | ||||
| 
 | ||||
| // Add adds an error to a given slice of errors
 | ||||
| func (errs Errors) Add(newErrors ...error) Errors { | ||||
| 	for _, err := range newErrors { | ||||
| 		if err == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if errors, ok := err.(Errors); ok { | ||||
| 			errs = errs.Add(errors...) | ||||
| 		} else { | ||||
| 			ok = true | ||||
| 			for _, e := range errs { | ||||
| 				if err == e { | ||||
| 					ok = false | ||||
| 				} | ||||
| 			} | ||||
| 			if ok { | ||||
| 				errs = append(errs, err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return errs | ||||
| } | ||||
| 
 | ||||
| // Error takes a slice of all errors that have occurred and returns it as a formatted string
 | ||||
| func (errs Errors) Error() string { | ||||
| 	var errors = []string{} | ||||
| 	for _, e := range errs { | ||||
| 		errors = append(errors, e.Error()) | ||||
| 	} | ||||
| 	return strings.Join(errors, "; ") | ||||
| } | ||||
| @ -1,20 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func TestErrorsCanBeUsedOutsideGorm(t *testing.T) { | ||||
| 	errs := []error{errors.New("First"), errors.New("Second")} | ||||
| 
 | ||||
| 	gErrs := gorm.Errors(errs) | ||||
| 	gErrs = gErrs.Add(errors.New("Third")) | ||||
| 	gErrs = gErrs.Add(gErrs) | ||||
| 
 | ||||
| 	if gErrs.Error() != "First; Second; Third" { | ||||
| 		t.Fatalf("Gave wrong error, got %s", gErrs.Error()) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										66
									
								
								field.go
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								field.go
									
									
									
									
									
								
							| @ -1,66 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // Field model field definition
 | ||||
| type Field struct { | ||||
| 	*StructField | ||||
| 	IsBlank bool | ||||
| 	Field   reflect.Value | ||||
| } | ||||
| 
 | ||||
| // Set set a value to the field
 | ||||
| func (field *Field) Set(value interface{}) (err error) { | ||||
| 	if !field.Field.IsValid() { | ||||
| 		return errors.New("field value not valid") | ||||
| 	} | ||||
| 
 | ||||
| 	if !field.Field.CanAddr() { | ||||
| 		return ErrUnaddressable | ||||
| 	} | ||||
| 
 | ||||
| 	reflectValue, ok := value.(reflect.Value) | ||||
| 	if !ok { | ||||
| 		reflectValue = reflect.ValueOf(value) | ||||
| 	} | ||||
| 
 | ||||
| 	fieldValue := field.Field | ||||
| 	if reflectValue.IsValid() { | ||||
| 		if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { | ||||
| 			fieldValue.Set(reflectValue.Convert(fieldValue.Type())) | ||||
| 		} else { | ||||
| 			if fieldValue.Kind() == reflect.Ptr { | ||||
| 				if fieldValue.IsNil() { | ||||
| 					fieldValue.Set(reflect.New(field.Struct.Type.Elem())) | ||||
| 				} | ||||
| 				fieldValue = fieldValue.Elem() | ||||
| 			} | ||||
| 
 | ||||
| 			if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { | ||||
| 				fieldValue.Set(reflectValue.Convert(fieldValue.Type())) | ||||
| 			} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { | ||||
| 				v := reflectValue.Interface() | ||||
| 				if valuer, ok := v.(driver.Valuer); ok { | ||||
| 					if v, err = valuer.Value(); err == nil { | ||||
| 						err = scanner.Scan(v) | ||||
| 					} | ||||
| 				} else { | ||||
| 					err = scanner.Scan(v) | ||||
| 				} | ||||
| 			} else { | ||||
| 				err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		field.Field.Set(reflect.Zero(field.Field.Type())) | ||||
| 	} | ||||
| 
 | ||||
| 	field.IsBlank = isBlank(field.Field) | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										130
									
								
								field_test.go
									
									
									
									
									
								
							
							
						
						
									
										130
									
								
								field_test.go
									
									
									
									
									
								
							| @ -1,130 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/hex" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| type CalculateField struct { | ||||
| 	gorm.Model | ||||
| 	Name     string | ||||
| 	Children []CalculateFieldChild | ||||
| 	Category CalculateFieldCategory | ||||
| 	EmbeddedField | ||||
| } | ||||
| 
 | ||||
| type EmbeddedField struct { | ||||
| 	EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` | ||||
| } | ||||
| 
 | ||||
| type CalculateFieldChild struct { | ||||
| 	gorm.Model | ||||
| 	CalculateFieldID uint | ||||
| 	Name             string | ||||
| } | ||||
| 
 | ||||
| type CalculateFieldCategory struct { | ||||
| 	gorm.Model | ||||
| 	CalculateFieldID uint | ||||
| 	Name             string | ||||
| } | ||||
| 
 | ||||
| func TestCalculateField(t *testing.T) { | ||||
| 	var field CalculateField | ||||
| 	var scope = DB.NewScope(&field) | ||||
| 	if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { | ||||
| 		t.Errorf("Should calculate fields correctly for the first time") | ||||
| 	} | ||||
| 
 | ||||
| 	if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { | ||||
| 		t.Errorf("Should calculate fields correctly for the first time") | ||||
| 	} | ||||
| 
 | ||||
| 	if field, ok := scope.FieldByName("embedded_name"); !ok { | ||||
| 		t.Errorf("should find embedded field") | ||||
| 	} else if _, ok := field.TagSettingsGet("NOT NULL"); !ok { | ||||
| 		t.Errorf("should find embedded field's tag settings") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type UUID [16]byte | ||||
| 
 | ||||
| type NullUUID struct { | ||||
| 	UUID | ||||
| 	Valid bool | ||||
| } | ||||
| 
 | ||||
| func FromString(input string) (u UUID) { | ||||
| 	src := []byte(input) | ||||
| 	return FromBytes(src) | ||||
| } | ||||
| 
 | ||||
| func FromBytes(src []byte) (u UUID) { | ||||
| 	dst := u[:] | ||||
| 	hex.Decode(dst[0:4], src[0:8]) | ||||
| 	hex.Decode(dst[4:6], src[9:13]) | ||||
| 	hex.Decode(dst[6:8], src[14:18]) | ||||
| 	hex.Decode(dst[8:10], src[19:23]) | ||||
| 	hex.Decode(dst[10:], src[24:]) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (u UUID) String() string { | ||||
| 	buf := make([]byte, 36) | ||||
| 	src := u[:] | ||||
| 	hex.Encode(buf[0:8], src[0:4]) | ||||
| 	buf[8] = '-' | ||||
| 	hex.Encode(buf[9:13], src[4:6]) | ||||
| 	buf[13] = '-' | ||||
| 	hex.Encode(buf[14:18], src[6:8]) | ||||
| 	buf[18] = '-' | ||||
| 	hex.Encode(buf[19:23], src[8:10]) | ||||
| 	buf[23] = '-' | ||||
| 	hex.Encode(buf[24:], src[10:]) | ||||
| 	return string(buf) | ||||
| } | ||||
| 
 | ||||
| func (u UUID) Value() (driver.Value, error) { | ||||
| 	return u.String(), nil | ||||
| } | ||||
| 
 | ||||
| func (u *UUID) Scan(src interface{}) error { | ||||
| 	switch src := src.(type) { | ||||
| 	case UUID: // support gorm convert from UUID to NullUUID
 | ||||
| 		*u = src | ||||
| 		return nil | ||||
| 	case []byte: | ||||
| 		*u = FromBytes(src) | ||||
| 		return nil | ||||
| 	case string: | ||||
| 		*u = FromString(src) | ||||
| 		return nil | ||||
| 	} | ||||
| 	return fmt.Errorf("uuid: cannot convert %T to UUID", src) | ||||
| } | ||||
| 
 | ||||
| func (u *NullUUID) Scan(src interface{}) error { | ||||
| 	u.Valid = true | ||||
| 	return u.UUID.Scan(src) | ||||
| } | ||||
| 
 | ||||
| func TestFieldSet(t *testing.T) { | ||||
| 	type TestFieldSetNullUUID struct { | ||||
| 		NullUUID NullUUID | ||||
| 	} | ||||
| 	scope := DB.NewScope(&TestFieldSetNullUUID{}) | ||||
| 	field := scope.Fields()[0] | ||||
| 	err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00")) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok { | ||||
| 		t.Fatal() | ||||
| 	} else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" { | ||||
| 		t.Fatal(id) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										13
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								go.mod
									
									
									
									
									
								
							| @ -1,14 +1 @@ | ||||
| module github.com/jinzhu/gorm | ||||
| 
 | ||||
| go 1.12 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd | ||||
| 	github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 | ||||
| 	github.com/go-sql-driver/mysql v1.5.0 | ||||
| 	github.com/jinzhu/inflection v1.0.0 | ||||
| 	github.com/jinzhu/now v1.0.1 | ||||
| 	github.com/lib/pq v1.1.1 | ||||
| 	github.com/mattn/go-sqlite3 v2.0.1+incompatible | ||||
| 	golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect | ||||
| ) | ||||
|  | ||||
							
								
								
									
										25
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								go.sum
									
									
									
									
									
								
							| @ -1,25 +0,0 @@ | ||||
| github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= | ||||
| github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= | ||||
| github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= | ||||
| github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= | ||||
| github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= | ||||
| github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= | ||||
| github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= | ||||
| github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= | ||||
| github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= | ||||
| github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= | ||||
| github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= | ||||
| github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= | ||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= | ||||
| golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= | ||||
| golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | ||||
| golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
							
								
								
									
										24
									
								
								interface.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								interface.go
									
									
									
									
									
								
							| @ -1,24 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| ) | ||||
| 
 | ||||
| // SQLCommon is the minimal database connection functionality gorm requires.  Implemented by *sql.DB.
 | ||||
| type SQLCommon interface { | ||||
| 	Exec(query string, args ...interface{}) (sql.Result, error) | ||||
| 	Prepare(query string) (*sql.Stmt, error) | ||||
| 	Query(query string, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRow(query string, args ...interface{}) *sql.Row | ||||
| } | ||||
| 
 | ||||
| type sqlDb interface { | ||||
| 	Begin() (*sql.Tx, error) | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||
| } | ||||
| 
 | ||||
| type sqlTx interface { | ||||
| 	Commit() error | ||||
| 	Rollback() error | ||||
| } | ||||
| @ -1,211 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // JoinTableHandlerInterface is an interface for how to handle many2many relations
 | ||||
| type JoinTableHandlerInterface interface { | ||||
| 	// initialize join table handler
 | ||||
| 	Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) | ||||
| 	// Table return join table's table name
 | ||||
| 	Table(db *DB) string | ||||
| 	// Add create relationship in join table for source and destination
 | ||||
| 	Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error | ||||
| 	// Delete delete relationship in join table for sources
 | ||||
| 	Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error | ||||
| 	// JoinWith query with `Join` conditions
 | ||||
| 	JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB | ||||
| 	// SourceForeignKeys return source foreign keys
 | ||||
| 	SourceForeignKeys() []JoinTableForeignKey | ||||
| 	// DestinationForeignKeys return destination foreign keys
 | ||||
| 	DestinationForeignKeys() []JoinTableForeignKey | ||||
| } | ||||
| 
 | ||||
| // JoinTableForeignKey join table foreign key struct
 | ||||
| type JoinTableForeignKey struct { | ||||
| 	DBName            string | ||||
| 	AssociationDBName string | ||||
| } | ||||
| 
 | ||||
| // JoinTableSource is a struct that contains model type and foreign keys
 | ||||
| type JoinTableSource struct { | ||||
| 	ModelType   reflect.Type | ||||
| 	ForeignKeys []JoinTableForeignKey | ||||
| } | ||||
| 
 | ||||
| // JoinTableHandler default join table handler
 | ||||
| type JoinTableHandler struct { | ||||
| 	TableName   string          `sql:"-"` | ||||
| 	Source      JoinTableSource `sql:"-"` | ||||
| 	Destination JoinTableSource `sql:"-"` | ||||
| } | ||||
| 
 | ||||
| // SourceForeignKeys return source foreign keys
 | ||||
| func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { | ||||
| 	return s.Source.ForeignKeys | ||||
| } | ||||
| 
 | ||||
| // DestinationForeignKeys return destination foreign keys
 | ||||
| func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { | ||||
| 	return s.Destination.ForeignKeys | ||||
| } | ||||
| 
 | ||||
| // Setup initialize a default join table handler
 | ||||
| func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { | ||||
| 	s.TableName = tableName | ||||
| 
 | ||||
| 	s.Source = JoinTableSource{ModelType: source} | ||||
| 	s.Source.ForeignKeys = []JoinTableForeignKey{} | ||||
| 	for idx, dbName := range relationship.ForeignFieldNames { | ||||
| 		s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ | ||||
| 			DBName:            relationship.ForeignDBNames[idx], | ||||
| 			AssociationDBName: dbName, | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	s.Destination = JoinTableSource{ModelType: destination} | ||||
| 	s.Destination.ForeignKeys = []JoinTableForeignKey{} | ||||
| 	for idx, dbName := range relationship.AssociationForeignFieldNames { | ||||
| 		s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ | ||||
| 			DBName:            relationship.AssociationForeignDBNames[idx], | ||||
| 			AssociationDBName: dbName, | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Table return join table's table name
 | ||||
| func (s JoinTableHandler) Table(db *DB) string { | ||||
| 	return DefaultTableNameHandler(db, s.TableName) | ||||
| } | ||||
| 
 | ||||
| func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { | ||||
| 	for _, source := range sources { | ||||
| 		scope := db.NewScope(source) | ||||
| 		modelType := scope.GetModelStruct().ModelType | ||||
| 
 | ||||
| 		for _, joinTableSource := range joinTableSources { | ||||
| 			if joinTableSource.ModelType == modelType { | ||||
| 				for _, foreignKey := range joinTableSource.ForeignKeys { | ||||
| 					if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { | ||||
| 						conditionMap[foreignKey.DBName] = field.Field.Interface() | ||||
| 					} | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Add create relationship in join table for source and destination
 | ||||
| func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { | ||||
| 	var ( | ||||
| 		scope        = db.NewScope("") | ||||
| 		conditionMap = map[string]interface{}{} | ||||
| 	) | ||||
| 
 | ||||
| 	// Update condition map for source
 | ||||
| 	s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) | ||||
| 
 | ||||
| 	// Update condition map for destination
 | ||||
| 	s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) | ||||
| 
 | ||||
| 	var assignColumns, binVars, conditions []string | ||||
| 	var values []interface{} | ||||
| 	for key, value := range conditionMap { | ||||
| 		assignColumns = append(assignColumns, scope.Quote(key)) | ||||
| 		binVars = append(binVars, `?`) | ||||
| 		conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) | ||||
| 		values = append(values, value) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, value := range values { | ||||
| 		values = append(values, value) | ||||
| 	} | ||||
| 
 | ||||
| 	quotedTable := scope.Quote(handler.Table(db)) | ||||
| 	sql := fmt.Sprintf( | ||||
| 		"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", | ||||
| 		quotedTable, | ||||
| 		strings.Join(assignColumns, ","), | ||||
| 		strings.Join(binVars, ","), | ||||
| 		scope.Dialect().SelectFromDummyTable(), | ||||
| 		quotedTable, | ||||
| 		strings.Join(conditions, " AND "), | ||||
| 	) | ||||
| 
 | ||||
| 	return db.Exec(sql, values...).Error | ||||
| } | ||||
| 
 | ||||
| // Delete delete relationship in join table for sources
 | ||||
| func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { | ||||
| 	var ( | ||||
| 		scope        = db.NewScope(nil) | ||||
| 		conditions   []string | ||||
| 		values       []interface{} | ||||
| 		conditionMap = map[string]interface{}{} | ||||
| 	) | ||||
| 
 | ||||
| 	s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) | ||||
| 
 | ||||
| 	for key, value := range conditionMap { | ||||
| 		conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) | ||||
| 		values = append(values, value) | ||||
| 	} | ||||
| 
 | ||||
| 	return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error | ||||
| } | ||||
| 
 | ||||
| // JoinWith query with `Join` conditions
 | ||||
| func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { | ||||
| 	var ( | ||||
| 		scope           = db.NewScope(source) | ||||
| 		tableName       = handler.Table(db) | ||||
| 		quotedTableName = scope.Quote(tableName) | ||||
| 		joinConditions  []string | ||||
| 		values          []interface{} | ||||
| 	) | ||||
| 
 | ||||
| 	if s.Source.ModelType == scope.GetModelStruct().ModelType { | ||||
| 		destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() | ||||
| 		for _, foreignKey := range s.Destination.ForeignKeys { | ||||
| 			joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) | ||||
| 		} | ||||
| 
 | ||||
| 		var foreignDBNames []string | ||||
| 		var foreignFieldNames []string | ||||
| 
 | ||||
| 		for _, foreignKey := range s.Source.ForeignKeys { | ||||
| 			foreignDBNames = append(foreignDBNames, foreignKey.DBName) | ||||
| 			if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { | ||||
| 				foreignFieldNames = append(foreignFieldNames, field.Name) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) | ||||
| 
 | ||||
| 		var condString string | ||||
| 		if len(foreignFieldValues) > 0 { | ||||
| 			var quotedForeignDBNames []string | ||||
| 			for _, dbName := range foreignDBNames { | ||||
| 				quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) | ||||
| 			} | ||||
| 
 | ||||
| 			condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) | ||||
| 
 | ||||
| 			keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) | ||||
| 			values = append(values, toQueryValues(keys)) | ||||
| 		} else { | ||||
| 			condString = fmt.Sprintf("1 <> 1") | ||||
| 		} | ||||
| 
 | ||||
| 		return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). | ||||
| 			Where(condString, toQueryValues(foreignFieldValues)...) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Error = errors.New("wrong source type for join table handler") | ||||
| 	return db | ||||
| } | ||||
| @ -1,117 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| type Person struct { | ||||
| 	Id        int | ||||
| 	Name      string | ||||
| 	Addresses []*Address `gorm:"many2many:person_addresses;"` | ||||
| } | ||||
| 
 | ||||
| type PersonAddress struct { | ||||
| 	gorm.JoinTableHandler | ||||
| 	PersonID  int | ||||
| 	AddressID int | ||||
| 	DeletedAt *time.Time | ||||
| 	CreatedAt time.Time | ||||
| } | ||||
| 
 | ||||
| func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { | ||||
| 	foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue())) | ||||
| 	associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue())) | ||||
| 	if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{ | ||||
| 		"person_id":  foreignPrimaryKey, | ||||
| 		"address_id": associationPrimaryKey, | ||||
| 	}).Update(map[string]interface{}{ | ||||
| 		"person_id":  foreignPrimaryKey, | ||||
| 		"address_id": associationPrimaryKey, | ||||
| 		"deleted_at": gorm.Expr("NULL"), | ||||
| 	}).RowsAffected; result == 0 { | ||||
| 		return db.Create(&PersonAddress{ | ||||
| 			PersonID:  foreignPrimaryKey, | ||||
| 			AddressID: associationPrimaryKey, | ||||
| 		}).Error | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { | ||||
| 	return db.Delete(&PersonAddress{}).Error | ||||
| } | ||||
| 
 | ||||
| func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { | ||||
| 	table := pa.Table(db) | ||||
| 	return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) | ||||
| } | ||||
| 
 | ||||
| func TestJoinTable(t *testing.T) { | ||||
| 	DB.Exec("drop table person_addresses;") | ||||
| 	DB.AutoMigrate(&Person{}) | ||||
| 	DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) | ||||
| 
 | ||||
| 	address1 := &Address{Address1: "address 1"} | ||||
| 	address2 := &Address{Address1: "address 2"} | ||||
| 	person := &Person{Name: "person", Addresses: []*Address{address1, address2}} | ||||
| 	DB.Save(person) | ||||
| 
 | ||||
| 	DB.Model(person).Association("Addresses").Delete(address1) | ||||
| 
 | ||||
| 	if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { | ||||
| 		t.Errorf("Should found one address") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(person).Association("Addresses").Count() != 1 { | ||||
| 		t.Errorf("Should found one address") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { | ||||
| 		t.Errorf("Found two addresses with Unscoped") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { | ||||
| 		t.Errorf("Should deleted all addresses") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEmbeddedMany2ManyRelationship(t *testing.T) { | ||||
| 	type EmbeddedPerson struct { | ||||
| 		ID        int | ||||
| 		Name      string | ||||
| 		Addresses []*Address `gorm:"many2many:person_addresses;"` | ||||
| 	} | ||||
| 
 | ||||
| 	type NewPerson struct { | ||||
| 		EmbeddedPerson | ||||
| 		ExternalID uint | ||||
| 	} | ||||
| 	DB.Exec("drop table person_addresses;") | ||||
| 	DB.AutoMigrate(&NewPerson{}) | ||||
| 
 | ||||
| 	address1 := &Address{Address1: "address 1"} | ||||
| 	address2 := &Address{Address1: "address 2"} | ||||
| 	person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}} | ||||
| 	if err := DB.Save(person).Error; err != nil { | ||||
| 		t.Errorf("no error should return when save embedded many2many relationship, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil { | ||||
| 		t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	association := DB.Model(person).Association("Addresses") | ||||
| 	if count := association.Count(); count != 1 || association.Error != nil { | ||||
| 		t.Errorf("Should found one address, but got %v, error is %v", count, association.Error) | ||||
| 	} | ||||
| 
 | ||||
| 	if association.Clear(); association.Count() != 0 { | ||||
| 		t.Errorf("Should deleted all addresses") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										141
									
								
								logger.go
									
									
									
									
									
								
							
							
						
						
									
										141
									
								
								logger.go
									
									
									
									
									
								
							| @ -1,141 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| 	"unicode" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	defaultLogger            = Logger{log.New(os.Stdout, "\r\n", 0)} | ||||
| 	sqlRegexp                = regexp.MustCompile(`\?`) | ||||
| 	numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) | ||||
| ) | ||||
| 
 | ||||
| func isPrintable(s string) bool { | ||||
| 	for _, r := range s { | ||||
| 		if !unicode.IsPrint(r) { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| var LogFormatter = func(values ...interface{}) (messages []interface{}) { | ||||
| 	if len(values) > 1 { | ||||
| 		var ( | ||||
| 			sql             string | ||||
| 			formattedValues []string | ||||
| 			level           = values[0] | ||||
| 			currentTime     = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" | ||||
| 			source          = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) | ||||
| 		) | ||||
| 
 | ||||
| 		messages = []interface{}{source, currentTime} | ||||
| 
 | ||||
| 		if len(values) == 2 { | ||||
| 			//remove the line break
 | ||||
| 			currentTime = currentTime[1:] | ||||
| 			//remove the brackets
 | ||||
| 			source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) | ||||
| 
 | ||||
| 			messages = []interface{}{currentTime, source} | ||||
| 		} | ||||
| 
 | ||||
| 		if level == "sql" { | ||||
| 			// duration
 | ||||
| 			messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) | ||||
| 			// sql
 | ||||
| 
 | ||||
| 			for _, value := range values[4].([]interface{}) { | ||||
| 				indirectValue := reflect.Indirect(reflect.ValueOf(value)) | ||||
| 				if indirectValue.IsValid() { | ||||
| 					value = indirectValue.Interface() | ||||
| 					if t, ok := value.(time.Time); ok { | ||||
| 						if t.IsZero() { | ||||
| 							formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) | ||||
| 						} else { | ||||
| 							formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) | ||||
| 						} | ||||
| 					} else if b, ok := value.([]byte); ok { | ||||
| 						if str := string(b); isPrintable(str) { | ||||
| 							formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) | ||||
| 						} else { | ||||
| 							formattedValues = append(formattedValues, "'<binary>'") | ||||
| 						} | ||||
| 					} else if r, ok := value.(driver.Valuer); ok { | ||||
| 						if value, err := r.Value(); err == nil && value != nil { | ||||
| 							formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) | ||||
| 						} else { | ||||
| 							formattedValues = append(formattedValues, "NULL") | ||||
| 						} | ||||
| 					} else { | ||||
| 						switch value.(type) { | ||||
| 						case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: | ||||
| 							formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) | ||||
| 						default: | ||||
| 							formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					formattedValues = append(formattedValues, "NULL") | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// differentiate between $n placeholders or else treat like ?
 | ||||
| 			if numericPlaceHolderRegexp.MatchString(values[3].(string)) { | ||||
| 				sql = values[3].(string) | ||||
| 				for index, value := range formattedValues { | ||||
| 					placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) | ||||
| 					sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") | ||||
| 				} | ||||
| 			} else { | ||||
| 				formattedValuesLength := len(formattedValues) | ||||
| 				for index, value := range sqlRegexp.Split(values[3].(string), -1) { | ||||
| 					sql += value | ||||
| 					if index < formattedValuesLength { | ||||
| 						sql += formattedValues[index] | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			messages = append(messages, sql) | ||||
| 			messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) | ||||
| 		} else { | ||||
| 			messages = append(messages, "\033[31;1m") | ||||
| 			messages = append(messages, values[2:]...) | ||||
| 			messages = append(messages, "\033[0m") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type logger interface { | ||||
| 	Print(v ...interface{}) | ||||
| } | ||||
| 
 | ||||
| // LogWriter log writer interface
 | ||||
| type LogWriter interface { | ||||
| 	Println(v ...interface{}) | ||||
| } | ||||
| 
 | ||||
| // Logger default logger
 | ||||
| type Logger struct { | ||||
| 	LogWriter | ||||
| } | ||||
| 
 | ||||
| // Print format & print log
 | ||||
| func (logger Logger) Print(values ...interface{}) { | ||||
| 	logger.Println(LogFormatter(values...)...) | ||||
| } | ||||
| 
 | ||||
| type nopLogger struct{} | ||||
| 
 | ||||
| func (nopLogger) Print(values ...interface{}) {} | ||||
							
								
								
									
										881
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										881
									
								
								main.go
									
									
									
									
									
								
							| @ -1,881 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // DB contains information for current db connection
 | ||||
| type DB struct { | ||||
| 	sync.RWMutex | ||||
| 	Value        interface{} | ||||
| 	Error        error | ||||
| 	RowsAffected int64 | ||||
| 
 | ||||
| 	// single db
 | ||||
| 	db                SQLCommon | ||||
| 	blockGlobalUpdate bool | ||||
| 	logMode           logModeValue | ||||
| 	logger            logger | ||||
| 	search            *search | ||||
| 	values            sync.Map | ||||
| 
 | ||||
| 	// global db
 | ||||
| 	parent        *DB | ||||
| 	callbacks     *Callback | ||||
| 	dialect       Dialect | ||||
| 	singularTable bool | ||||
| 
 | ||||
| 	// function to be used to override the creating of a new timestamp
 | ||||
| 	nowFuncOverride func() time.Time | ||||
| } | ||||
| 
 | ||||
| type logModeValue int | ||||
| 
 | ||||
| const ( | ||||
| 	defaultLogMode logModeValue = iota | ||||
| 	noLogMode | ||||
| 	detailedLogMode | ||||
| ) | ||||
| 
 | ||||
| // Open initialize a new db connection, need to import driver first, e.g:
 | ||||
| //
 | ||||
| //     import _ "github.com/go-sql-driver/mysql"
 | ||||
| //     func main() {
 | ||||
| //       db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
 | ||||
| //     }
 | ||||
| // GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
 | ||||
| //    import _ "github.com/jinzhu/gorm/dialects/mysql"
 | ||||
| //    // import _ "github.com/jinzhu/gorm/dialects/postgres"
 | ||||
| //    // import _ "github.com/jinzhu/gorm/dialects/sqlite"
 | ||||
| //    // import _ "github.com/jinzhu/gorm/dialects/mssql"
 | ||||
| func Open(dialect string, args ...interface{}) (db *DB, err error) { | ||||
| 	if len(args) == 0 { | ||||
| 		err = errors.New("invalid database source") | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	var source string | ||||
| 	var dbSQL SQLCommon | ||||
| 	var ownDbSQL bool | ||||
| 
 | ||||
| 	switch value := args[0].(type) { | ||||
| 	case string: | ||||
| 		var driver = dialect | ||||
| 		if len(args) == 1 { | ||||
| 			source = value | ||||
| 		} else if len(args) >= 2 { | ||||
| 			driver = value | ||||
| 			source = args[1].(string) | ||||
| 		} | ||||
| 		dbSQL, err = sql.Open(driver, source) | ||||
| 		ownDbSQL = true | ||||
| 	case SQLCommon: | ||||
| 		dbSQL = value | ||||
| 		ownDbSQL = false | ||||
| 	default: | ||||
| 		return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) | ||||
| 	} | ||||
| 
 | ||||
| 	db = &DB{ | ||||
| 		db:        dbSQL, | ||||
| 		logger:    defaultLogger, | ||||
| 		callbacks: DefaultCallback, | ||||
| 		dialect:   newDialect(dialect, dbSQL), | ||||
| 	} | ||||
| 	db.parent = db | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	// Send a ping to make sure the database connection is alive.
 | ||||
| 	if d, ok := dbSQL.(*sql.DB); ok { | ||||
| 		if err = d.Ping(); err != nil && ownDbSQL { | ||||
| 			d.Close() | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // New clone a new db connection without search conditions
 | ||||
| func (s *DB) New() *DB { | ||||
| 	clone := s.clone() | ||||
| 	clone.search = nil | ||||
| 	clone.Value = nil | ||||
| 	return clone | ||||
| } | ||||
| 
 | ||||
| type closer interface { | ||||
| 	Close() error | ||||
| } | ||||
| 
 | ||||
| // Close close current db connection.  If database connection is not an io.Closer, returns an error.
 | ||||
| func (s *DB) Close() error { | ||||
| 	if db, ok := s.parent.db.(closer); ok { | ||||
| 		return db.Close() | ||||
| 	} | ||||
| 	return errors.New("can't close current db") | ||||
| } | ||||
| 
 | ||||
| // DB get `*sql.DB` from current connection
 | ||||
| // If the underlying database connection is not a *sql.DB, returns nil
 | ||||
| func (s *DB) DB() *sql.DB { | ||||
| 	db, ok := s.db.(*sql.DB) | ||||
| 	if !ok { | ||||
| 		panic("can't support full GORM on currently status, maybe this is a TX instance.") | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
 | ||||
| func (s *DB) CommonDB() SQLCommon { | ||||
| 	return s.db | ||||
| } | ||||
| 
 | ||||
| // Dialect get dialect
 | ||||
| func (s *DB) Dialect() Dialect { | ||||
| 	return s.dialect | ||||
| } | ||||
| 
 | ||||
| // Callback return `Callbacks` container, you could add/change/delete callbacks with it
 | ||||
| //     db.Callback().Create().Register("update_created_at", updateCreated)
 | ||||
| // Refer https://jinzhu.github.io/gorm/development.html#callbacks
 | ||||
| func (s *DB) Callback() *Callback { | ||||
| 	s.parent.callbacks = s.parent.callbacks.clone(s.logger) | ||||
| 	return s.parent.callbacks | ||||
| } | ||||
| 
 | ||||
| // SetLogger replace default logger
 | ||||
| func (s *DB) SetLogger(log logger) { | ||||
| 	s.logger = log | ||||
| } | ||||
| 
 | ||||
| // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
 | ||||
| func (s *DB) LogMode(enable bool) *DB { | ||||
| 	if enable { | ||||
| 		s.logMode = detailedLogMode | ||||
| 	} else { | ||||
| 		s.logMode = noLogMode | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // SetNowFuncOverride set the function to be used when creating a new timestamp
 | ||||
| func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { | ||||
| 	s.nowFuncOverride = nowFuncOverride | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
 | ||||
| // otherwise defaults to the global NowFunc()
 | ||||
| func (s *DB) nowFunc() time.Time { | ||||
| 	if s.nowFuncOverride != nil { | ||||
| 		return s.nowFuncOverride() | ||||
| 	} | ||||
| 
 | ||||
| 	return NowFunc() | ||||
| } | ||||
| 
 | ||||
| // BlockGlobalUpdate if true, generates an error on update/delete without where clause.
 | ||||
| // This is to prevent eventual error with empty objects updates/deletions
 | ||||
| func (s *DB) BlockGlobalUpdate(enable bool) *DB { | ||||
| 	s.blockGlobalUpdate = enable | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // HasBlockGlobalUpdate return state of block
 | ||||
| func (s *DB) HasBlockGlobalUpdate() bool { | ||||
| 	return s.blockGlobalUpdate | ||||
| } | ||||
| 
 | ||||
| // SingularTable use singular table by default
 | ||||
| func (s *DB) SingularTable(enable bool) { | ||||
| 	s.parent.Lock() | ||||
| 	defer s.parent.Unlock() | ||||
| 	s.parent.singularTable = enable | ||||
| } | ||||
| 
 | ||||
| // NewScope create a scope for current operation
 | ||||
| func (s *DB) NewScope(value interface{}) *Scope { | ||||
| 	dbClone := s.clone() | ||||
| 	dbClone.Value = value | ||||
| 	scope := &Scope{db: dbClone, Value: value} | ||||
| 	if s.search != nil { | ||||
| 		scope.Search = s.search.clone() | ||||
| 	} else { | ||||
| 		scope.Search = &search{} | ||||
| 	} | ||||
| 	return scope | ||||
| } | ||||
| 
 | ||||
| // QueryExpr returns the query as SqlExpr object
 | ||||
| func (s *DB) QueryExpr() *SqlExpr { | ||||
| 	scope := s.NewScope(s.Value) | ||||
| 	scope.InstanceSet("skip_bindvar", true) | ||||
| 	scope.prepareQuerySQL() | ||||
| 
 | ||||
| 	return Expr(scope.SQL, scope.SQLVars...) | ||||
| } | ||||
| 
 | ||||
| // SubQuery returns the query as sub query
 | ||||
| func (s *DB) SubQuery() *SqlExpr { | ||||
| 	scope := s.NewScope(s.Value) | ||||
| 	scope.InstanceSet("skip_bindvar", true) | ||||
| 	scope.prepareQuerySQL() | ||||
| 
 | ||||
| 	return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) | ||||
| } | ||||
| 
 | ||||
| // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
 | ||||
| func (s *DB) Where(query interface{}, args ...interface{}) *DB { | ||||
| 	return s.clone().search.Where(query, args...).db | ||||
| } | ||||
| 
 | ||||
| // Or filter records that match before conditions or this one, similar to `Where`
 | ||||
| func (s *DB) Or(query interface{}, args ...interface{}) *DB { | ||||
| 	return s.clone().search.Or(query, args...).db | ||||
| } | ||||
| 
 | ||||
| // Not filter records that don't match current conditions, similar to `Where`
 | ||||
| func (s *DB) Not(query interface{}, args ...interface{}) *DB { | ||||
| 	return s.clone().search.Not(query, args...).db | ||||
| } | ||||
| 
 | ||||
| // Limit specify the number of records to be retrieved
 | ||||
| func (s *DB) Limit(limit interface{}) *DB { | ||||
| 	return s.clone().search.Limit(limit).db | ||||
| } | ||||
| 
 | ||||
| // Offset specify the number of records to skip before starting to return the records
 | ||||
| func (s *DB) Offset(offset interface{}) *DB { | ||||
| 	return s.clone().search.Offset(offset).db | ||||
| } | ||||
| 
 | ||||
| // Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
 | ||||
| //     db.Order("name DESC")
 | ||||
| //     db.Order("name DESC", true) // reorder
 | ||||
| //     db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
 | ||||
| func (s *DB) Order(value interface{}, reorder ...bool) *DB { | ||||
| 	return s.clone().search.Order(value, reorder...).db | ||||
| } | ||||
| 
 | ||||
| // Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
 | ||||
| // When creating/updating, specify fields that you want to save to database
 | ||||
| func (s *DB) Select(query interface{}, args ...interface{}) *DB { | ||||
| 	return s.clone().search.Select(query, args...).db | ||||
| } | ||||
| 
 | ||||
| // Omit specify fields that you want to ignore when saving to database for creating, updating
 | ||||
| func (s *DB) Omit(columns ...string) *DB { | ||||
| 	return s.clone().search.Omit(columns...).db | ||||
| } | ||||
| 
 | ||||
| // Group specify the group method on the find
 | ||||
| func (s *DB) Group(query string) *DB { | ||||
| 	return s.clone().search.Group(query).db | ||||
| } | ||||
| 
 | ||||
| // Having specify HAVING conditions for GROUP BY
 | ||||
| func (s *DB) Having(query interface{}, values ...interface{}) *DB { | ||||
| 	return s.clone().search.Having(query, values...).db | ||||
| } | ||||
| 
 | ||||
| // Joins specify Joins conditions
 | ||||
| //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | ||||
| func (s *DB) Joins(query string, args ...interface{}) *DB { | ||||
| 	return s.clone().search.Joins(query, args...).db | ||||
| } | ||||
| 
 | ||||
| // Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
 | ||||
| //     func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
 | ||||
| //         return db.Where("amount > ?", 1000)
 | ||||
| //     }
 | ||||
| //
 | ||||
| //     func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
 | ||||
| //         return func (db *gorm.DB) *gorm.DB {
 | ||||
| //             return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
 | ||||
| //         }
 | ||||
| //     }
 | ||||
| //
 | ||||
| //     db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
 | ||||
| // Refer https://jinzhu.github.io/gorm/crud.html#scopes
 | ||||
| func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { | ||||
| 	for _, f := range funcs { | ||||
| 		s = f(s) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete
 | ||||
| func (s *DB) Unscoped() *DB { | ||||
| 	return s.clone().search.unscoped().db | ||||
| } | ||||
| 
 | ||||
| // Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
 | ||||
| func (s *DB) Attrs(attrs ...interface{}) *DB { | ||||
| 	return s.clone().search.Attrs(attrs...).db | ||||
| } | ||||
| 
 | ||||
| // Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
 | ||||
| func (s *DB) Assign(attrs ...interface{}) *DB { | ||||
| 	return s.clone().search.Assign(attrs...).db | ||||
| } | ||||
| 
 | ||||
| // First find first record that match given conditions, order by primary key
 | ||||
| func (s *DB) First(out interface{}, where ...interface{}) *DB { | ||||
| 	newScope := s.NewScope(out) | ||||
| 	newScope.Search.Limit(1) | ||||
| 
 | ||||
| 	return newScope.Set("gorm:order_by_primary_key", "ASC"). | ||||
| 		inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | ||||
| } | ||||
| 
 | ||||
| // Take return a record that match given conditions, the order will depend on the database implementation
 | ||||
| func (s *DB) Take(out interface{}, where ...interface{}) *DB { | ||||
| 	newScope := s.NewScope(out) | ||||
| 	newScope.Search.Limit(1) | ||||
| 	return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | ||||
| } | ||||
| 
 | ||||
| // Last find last record that match given conditions, order by primary key
 | ||||
| func (s *DB) Last(out interface{}, where ...interface{}) *DB { | ||||
| 	newScope := s.NewScope(out) | ||||
| 	newScope.Search.Limit(1) | ||||
| 	return newScope.Set("gorm:order_by_primary_key", "DESC"). | ||||
| 		inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | ||||
| } | ||||
| 
 | ||||
| // Find find records that match given conditions
 | ||||
| func (s *DB) Find(out interface{}, where ...interface{}) *DB { | ||||
| 	return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db | ||||
| } | ||||
| 
 | ||||
| //Preloads preloads relations, don`t touch out
 | ||||
| func (s *DB) Preloads(out interface{}) *DB { | ||||
| 	return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db | ||||
| } | ||||
| 
 | ||||
| // Scan scan value to a struct
 | ||||
| func (s *DB) Scan(dest interface{}) *DB { | ||||
| 	return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db | ||||
| } | ||||
| 
 | ||||
| // Row return `*sql.Row` with given conditions
 | ||||
| func (s *DB) Row() *sql.Row { | ||||
| 	return s.NewScope(s.Value).row() | ||||
| } | ||||
| 
 | ||||
| // Rows return `*sql.Rows` with given conditions
 | ||||
| func (s *DB) Rows() (*sql.Rows, error) { | ||||
| 	return s.NewScope(s.Value).rows() | ||||
| } | ||||
| 
 | ||||
| // ScanRows scan `*sql.Rows` to give struct
 | ||||
| func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { | ||||
| 	var ( | ||||
| 		scope        = s.NewScope(result) | ||||
| 		clone        = scope.db | ||||
| 		columns, err = rows.Columns() | ||||
| 	) | ||||
| 
 | ||||
| 	if clone.AddError(err) == nil { | ||||
| 		scope.scan(rows, columns, scope.Fields()) | ||||
| 	} | ||||
| 
 | ||||
| 	return clone.Error | ||||
| } | ||||
| 
 | ||||
| // Pluck used to query single column from a model as a map
 | ||||
| //     var ages []int64
 | ||||
| //     db.Find(&users).Pluck("age", &ages)
 | ||||
| func (s *DB) Pluck(column string, value interface{}) *DB { | ||||
| 	return s.NewScope(s.Value).pluck(column, value).db | ||||
| } | ||||
| 
 | ||||
| // Count get how many records for a model
 | ||||
| func (s *DB) Count(value interface{}) *DB { | ||||
| 	return s.NewScope(s.Value).count(value).db | ||||
| } | ||||
| 
 | ||||
| // Related get related associations
 | ||||
| func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { | ||||
| 	return s.NewScope(s.Value).related(value, foreignKeys...).db | ||||
| } | ||||
| 
 | ||||
| // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
 | ||||
| // https://jinzhu.github.io/gorm/crud.html#firstorinit
 | ||||
| func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { | ||||
| 	c := s.clone() | ||||
| 	if result := c.First(out, where...); result.Error != nil { | ||||
| 		if !result.RecordNotFound() { | ||||
| 			return result | ||||
| 		} | ||||
| 		c.NewScope(out).inlineCondition(where...).initialize() | ||||
| 	} else { | ||||
| 		c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
 | ||||
| // https://jinzhu.github.io/gorm/crud.html#firstorcreate
 | ||||
| func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { | ||||
| 	c := s.clone() | ||||
| 	if result := s.First(out, where...); result.Error != nil { | ||||
| 		if !result.RecordNotFound() { | ||||
| 			return result | ||||
| 		} | ||||
| 		return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db | ||||
| 	} else if len(c.search.assignAttrs) > 0 { | ||||
| 		return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||
| // WARNING when update with struct, GORM will not update fields that with zero value
 | ||||
| func (s *DB) Update(attrs ...interface{}) *DB { | ||||
| 	return s.Updates(toSearchableMap(attrs...), true) | ||||
| } | ||||
| 
 | ||||
| // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||
| func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { | ||||
| 	return s.NewScope(s.Value). | ||||
| 		Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). | ||||
| 		InstanceSet("gorm:update_interface", values). | ||||
| 		callCallbacks(s.parent.callbacks.updates).db | ||||
| } | ||||
| 
 | ||||
| // UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||
| func (s *DB) UpdateColumn(attrs ...interface{}) *DB { | ||||
| 	return s.UpdateColumns(toSearchableMap(attrs...)) | ||||
| } | ||||
| 
 | ||||
| // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||
| func (s *DB) UpdateColumns(values interface{}) *DB { | ||||
| 	return s.NewScope(s.Value). | ||||
| 		Set("gorm:update_column", true). | ||||
| 		Set("gorm:save_associations", false). | ||||
| 		InstanceSet("gorm:update_interface", values). | ||||
| 		callCallbacks(s.parent.callbacks.updates).db | ||||
| } | ||||
| 
 | ||||
| // Save update value in database, if the value doesn't have primary key, will insert it
 | ||||
| func (s *DB) Save(value interface{}) *DB { | ||||
| 	scope := s.NewScope(value) | ||||
| 	if !scope.PrimaryKeyZero() { | ||||
| 		newDB := scope.callCallbacks(s.parent.callbacks.updates).db | ||||
| 		if newDB.Error == nil && newDB.RowsAffected == 0 { | ||||
| 			return s.New().Table(scope.TableName()).FirstOrCreate(value) | ||||
| 		} | ||||
| 		return newDB | ||||
| 	} | ||||
| 	return scope.callCallbacks(s.parent.callbacks.creates).db | ||||
| } | ||||
| 
 | ||||
| // Create insert the value into database
 | ||||
| func (s *DB) Create(value interface{}) *DB { | ||||
| 	scope := s.NewScope(value) | ||||
| 	return scope.callCallbacks(s.parent.callbacks.creates).db | ||||
| } | ||||
| 
 | ||||
| // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
 | ||||
| // WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
 | ||||
| func (s *DB) Delete(value interface{}, where ...interface{}) *DB { | ||||
| 	return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db | ||||
| } | ||||
| 
 | ||||
| // Raw use raw sql as conditions, won't run it unless invoked by other methods
 | ||||
| //    db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
 | ||||
| func (s *DB) Raw(sql string, values ...interface{}) *DB { | ||||
| 	return s.clone().search.Raw(true).Where(sql, values...).db | ||||
| } | ||||
| 
 | ||||
| // Exec execute raw sql
 | ||||
| func (s *DB) Exec(sql string, values ...interface{}) *DB { | ||||
| 	scope := s.NewScope(nil) | ||||
| 	generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) | ||||
| 	generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") | ||||
| 	scope.Raw(generatedSQL) | ||||
| 	return scope.Exec().db | ||||
| } | ||||
| 
 | ||||
| // Model specify the model you would like to run db operations
 | ||||
| //    // update all users's name to `hello`
 | ||||
| //    db.Model(&User{}).Update("name", "hello")
 | ||||
| //    // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
 | ||||
| //    db.Model(&user).Update("name", "hello")
 | ||||
| func (s *DB) Model(value interface{}) *DB { | ||||
| 	c := s.clone() | ||||
| 	c.Value = value | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // Table specify the table you would like to run db operations
 | ||||
| func (s *DB) Table(name string) *DB { | ||||
| 	clone := s.clone() | ||||
| 	clone.search.Table(name) | ||||
| 	clone.Value = nil | ||||
| 	return clone | ||||
| } | ||||
| 
 | ||||
| // Debug start debug mode
 | ||||
| func (s *DB) Debug() *DB { | ||||
| 	return s.clone().LogMode(true) | ||||
| } | ||||
| 
 | ||||
| // Transaction start a transaction as a block,
 | ||||
| // return error will rollback, otherwise to commit.
 | ||||
| func (s *DB) Transaction(fc func(tx *DB) error) (err error) { | ||||
| 	panicked := true | ||||
| 	tx := s.Begin() | ||||
| 	defer func() { | ||||
| 		// Make sure to rollback when panic, Block error or Commit error
 | ||||
| 		if panicked || err != nil { | ||||
| 			tx.Rollback() | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	err = fc(tx) | ||||
| 
 | ||||
| 	if err == nil { | ||||
| 		err = tx.Commit().Error | ||||
| 	} | ||||
| 
 | ||||
| 	panicked = false | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Begin begins a transaction
 | ||||
| func (s *DB) Begin() *DB { | ||||
| 	return s.BeginTx(context.Background(), &sql.TxOptions{}) | ||||
| } | ||||
| 
 | ||||
| // BeginTx begins a transaction with options
 | ||||
| func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { | ||||
| 	c := s.clone() | ||||
| 	if db, ok := c.db.(sqlDb); ok && db != nil { | ||||
| 		tx, err := db.BeginTx(ctx, opts) | ||||
| 		c.db = interface{}(tx).(SQLCommon) | ||||
| 
 | ||||
| 		c.dialect.SetDB(c.db) | ||||
| 		c.AddError(err) | ||||
| 	} else { | ||||
| 		c.AddError(ErrCantStartTransaction) | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // Commit commit a transaction
 | ||||
| func (s *DB) Commit() *DB { | ||||
| 	var emptySQLTx *sql.Tx | ||||
| 	if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { | ||||
| 		s.AddError(db.Commit()) | ||||
| 	} else { | ||||
| 		s.AddError(ErrInvalidTransaction) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // Rollback rollback a transaction
 | ||||
| func (s *DB) Rollback() *DB { | ||||
| 	var emptySQLTx *sql.Tx | ||||
| 	if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { | ||||
| 		if err := db.Rollback(); err != nil && err != sql.ErrTxDone { | ||||
| 			s.AddError(err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		s.AddError(ErrInvalidTransaction) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // RollbackUnlessCommitted rollback a transaction if it has not yet been
 | ||||
| // committed.
 | ||||
| func (s *DB) RollbackUnlessCommitted() *DB { | ||||
| 	var emptySQLTx *sql.Tx | ||||
| 	if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { | ||||
| 		err := db.Rollback() | ||||
| 		// Ignore the error indicating that the transaction has already
 | ||||
| 		// been committed.
 | ||||
| 		if err != sql.ErrTxDone { | ||||
| 			s.AddError(err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		s.AddError(ErrInvalidTransaction) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // NewRecord check if value's primary key is blank
 | ||||
| func (s *DB) NewRecord(value interface{}) bool { | ||||
| 	return s.NewScope(value).PrimaryKeyZero() | ||||
| } | ||||
| 
 | ||||
| // RecordNotFound check if returning ErrRecordNotFound error
 | ||||
| func (s *DB) RecordNotFound() bool { | ||||
| 	for _, err := range s.GetErrors() { | ||||
| 		if err == ErrRecordNotFound { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // CreateTable create table for models
 | ||||
| func (s *DB) CreateTable(models ...interface{}) *DB { | ||||
| 	db := s.Unscoped() | ||||
| 	for _, model := range models { | ||||
| 		db = db.NewScope(model).createTable().db | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| // DropTable drop table for models
 | ||||
| func (s *DB) DropTable(values ...interface{}) *DB { | ||||
| 	db := s.clone() | ||||
| 	for _, value := range values { | ||||
| 		if tableName, ok := value.(string); ok { | ||||
| 			db = db.Table(tableName) | ||||
| 		} | ||||
| 
 | ||||
| 		db = db.NewScope(value).dropTable().db | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| // DropTableIfExists drop table if it is exist
 | ||||
| func (s *DB) DropTableIfExists(values ...interface{}) *DB { | ||||
| 	db := s.clone() | ||||
| 	for _, value := range values { | ||||
| 		if s.HasTable(value) { | ||||
| 			db.AddError(s.DropTable(value).Error) | ||||
| 		} | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| // HasTable check has table or not
 | ||||
| func (s *DB) HasTable(value interface{}) bool { | ||||
| 	var ( | ||||
| 		scope     = s.NewScope(value) | ||||
| 		tableName string | ||||
| 	) | ||||
| 
 | ||||
| 	if name, ok := value.(string); ok { | ||||
| 		tableName = name | ||||
| 	} else { | ||||
| 		tableName = scope.TableName() | ||||
| 	} | ||||
| 
 | ||||
| 	has := scope.Dialect().HasTable(tableName) | ||||
| 	s.AddError(scope.db.Error) | ||||
| 	return has | ||||
| } | ||||
| 
 | ||||
| // AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
 | ||||
| func (s *DB) AutoMigrate(values ...interface{}) *DB { | ||||
| 	db := s.Unscoped() | ||||
| 	for _, value := range values { | ||||
| 		db = db.NewScope(value).autoMigrate().db | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| // ModifyColumn modify column to type
 | ||||
| func (s *DB) ModifyColumn(column string, typ string) *DB { | ||||
| 	scope := s.NewScope(s.Value) | ||||
| 	scope.modifyColumn(column, typ) | ||||
| 	return scope.db | ||||
| } | ||||
| 
 | ||||
| // DropColumn drop a column
 | ||||
| func (s *DB) DropColumn(column string) *DB { | ||||
| 	scope := s.NewScope(s.Value) | ||||
| 	scope.dropColumn(column) | ||||
| 	return scope.db | ||||
| } | ||||
| 
 | ||||
| // AddIndex add index for columns with given name
 | ||||
| func (s *DB) AddIndex(indexName string, columns ...string) *DB { | ||||
| 	scope := s.Unscoped().NewScope(s.Value) | ||||
| 	scope.addIndex(false, indexName, columns...) | ||||
| 	return scope.db | ||||
| } | ||||
| 
 | ||||
| // AddUniqueIndex add unique index for columns with given name
 | ||||
| func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { | ||||
| 	scope := s.Unscoped().NewScope(s.Value) | ||||
| 	scope.addIndex(true, indexName, columns...) | ||||
| 	return scope.db | ||||
| } | ||||
| 
 | ||||
| // RemoveIndex remove index with name
 | ||||
| func (s *DB) RemoveIndex(indexName string) *DB { | ||||
| 	scope := s.NewScope(s.Value) | ||||
| 	scope.removeIndex(indexName) | ||||
| 	return scope.db | ||||
| } | ||||
| 
 | ||||
| // AddForeignKey Add foreign key to the given scope, e.g:
 | ||||
| //     db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
 | ||||
| func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { | ||||
| 	scope := s.NewScope(s.Value) | ||||
| 	scope.addForeignKey(field, dest, onDelete, onUpdate) | ||||
| 	return scope.db | ||||
| } | ||||
| 
 | ||||
| // RemoveForeignKey Remove foreign key from the given scope, e.g:
 | ||||
| //     db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
 | ||||
| func (s *DB) RemoveForeignKey(field string, dest string) *DB { | ||||
| 	scope := s.clone().NewScope(s.Value) | ||||
| 	scope.removeForeignKey(field, dest) | ||||
| 	return scope.db | ||||
| } | ||||
| 
 | ||||
| // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
 | ||||
| func (s *DB) Association(column string) *Association { | ||||
| 	var err error | ||||
| 	var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) | ||||
| 
 | ||||
| 	if primaryField := scope.PrimaryField(); primaryField.IsBlank { | ||||
| 		err = errors.New("primary key can't be nil") | ||||
| 	} else { | ||||
| 		if field, ok := scope.FieldByName(column); ok { | ||||
| 			if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { | ||||
| 				err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) | ||||
| 			} else { | ||||
| 				return &Association{scope: scope, column: column, field: field} | ||||
| 			} | ||||
| 		} else { | ||||
| 			err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return &Association{Error: err} | ||||
| } | ||||
| 
 | ||||
| // Preload preload associations with given conditions
 | ||||
| //    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | ||||
| func (s *DB) Preload(column string, conditions ...interface{}) *DB { | ||||
| 	return s.clone().search.Preload(column, conditions...).db | ||||
| } | ||||
| 
 | ||||
| // Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
 | ||||
| func (s *DB) Set(name string, value interface{}) *DB { | ||||
| 	return s.clone().InstantSet(name, value) | ||||
| } | ||||
| 
 | ||||
| // InstantSet instant set setting, will affect current db
 | ||||
| func (s *DB) InstantSet(name string, value interface{}) *DB { | ||||
| 	s.values.Store(name, value) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // Get get setting by name
 | ||||
| func (s *DB) Get(name string) (value interface{}, ok bool) { | ||||
| 	value, ok = s.values.Load(name) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // SetJoinTableHandler set a model's join table handler for a relation
 | ||||
| func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { | ||||
| 	scope := s.NewScope(source) | ||||
| 	for _, field := range scope.GetModelStruct().StructFields { | ||||
| 		if field.Name == column || field.DBName == column { | ||||
| 			if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { | ||||
| 				source := (&Scope{Value: source}).GetModelStruct().ModelType | ||||
| 				destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType | ||||
| 				handler.Setup(field.Relationship, many2many, source, destination) | ||||
| 				field.Relationship.JoinTableHandler = handler | ||||
| 				if table := handler.Table(s); scope.Dialect().HasTable(table) { | ||||
| 					s.Table(table).AutoMigrate(handler) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // AddError add error to the db
 | ||||
| func (s *DB) AddError(err error) error { | ||||
| 	if err != nil { | ||||
| 		if err != ErrRecordNotFound { | ||||
| 			if s.logMode == defaultLogMode { | ||||
| 				go s.print("error", fileWithLineNum(), err) | ||||
| 			} else { | ||||
| 				s.log(err) | ||||
| 			} | ||||
| 
 | ||||
| 			errors := Errors(s.GetErrors()) | ||||
| 			errors = errors.Add(err) | ||||
| 			if len(errors) > 1 { | ||||
| 				err = errors | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		s.Error = err | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // GetErrors get happened errors from the db
 | ||||
| func (s *DB) GetErrors() []error { | ||||
| 	if errs, ok := s.Error.(Errors); ok { | ||||
| 		return errs | ||||
| 	} else if s.Error != nil { | ||||
| 		return []error{s.Error} | ||||
| 	} | ||||
| 	return []error{} | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| // Private Methods For DB
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| 
 | ||||
| func (s *DB) clone() *DB { | ||||
| 	db := &DB{ | ||||
| 		db:                s.db, | ||||
| 		parent:            s.parent, | ||||
| 		logger:            s.logger, | ||||
| 		logMode:           s.logMode, | ||||
| 		Value:             s.Value, | ||||
| 		Error:             s.Error, | ||||
| 		blockGlobalUpdate: s.blockGlobalUpdate, | ||||
| 		dialect:           newDialect(s.dialect.GetName(), s.db), | ||||
| 		nowFuncOverride:   s.nowFuncOverride, | ||||
| 	} | ||||
| 
 | ||||
| 	s.values.Range(func(k, v interface{}) bool { | ||||
| 		db.values.Store(k, v) | ||||
| 		return true | ||||
| 	}) | ||||
| 
 | ||||
| 	if s.search == nil { | ||||
| 		db.search = &search{limit: -1, offset: -1} | ||||
| 	} else { | ||||
| 		db.search = s.search.clone() | ||||
| 	} | ||||
| 
 | ||||
| 	db.search.db = db | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| func (s *DB) print(v ...interface{}) { | ||||
| 	s.logger.Print(v...) | ||||
| } | ||||
| 
 | ||||
| func (s *DB) log(v ...interface{}) { | ||||
| 	if s != nil && s.logMode == detailedLogMode { | ||||
| 		s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { | ||||
| 	if s.logMode == detailedLogMode { | ||||
| 		s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										1444
									
								
								main_test.go
									
									
									
									
									
								
							
							
						
						
									
										1444
									
								
								main_test.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,579 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| type User struct { | ||||
| 	Id                int64 | ||||
| 	Age               int64 | ||||
| 	UserNum           Num | ||||
| 	Name              string `sql:"size:255"` | ||||
| 	Email             string | ||||
| 	Birthday          *time.Time    // Time
 | ||||
| 	CreatedAt         time.Time     // CreatedAt: Time of record is created, will be insert automatically
 | ||||
| 	UpdatedAt         time.Time     // UpdatedAt: Time of record is updated, will be updated automatically
 | ||||
| 	Emails            []Email       // Embedded structs
 | ||||
| 	BillingAddress    Address       // Embedded struct
 | ||||
| 	BillingAddressID  sql.NullInt64 // Embedded struct's foreign key
 | ||||
| 	ShippingAddress   Address       // Embedded struct
 | ||||
| 	ShippingAddressId int64         // Embedded struct's foreign key
 | ||||
| 	CreditCard        CreditCard | ||||
| 	Latitude          float64 | ||||
| 	Languages         []Language `gorm:"many2many:user_languages;"` | ||||
| 	CompanyID         *int | ||||
| 	Company           Company | ||||
| 	Role              Role | ||||
| 	Password          EncryptedData | ||||
| 	PasswordHash      []byte | ||||
| 	IgnoreMe          int64                 `sql:"-"` | ||||
| 	IgnoreStringSlice []string              `sql:"-"` | ||||
| 	Ignored           struct{ Name string } `sql:"-"` | ||||
| 	IgnoredPointer    *User                 `sql:"-"` | ||||
| } | ||||
| 
 | ||||
| type NotSoLongTableName struct { | ||||
| 	Id                int64 | ||||
| 	ReallyLongThingID int64 | ||||
| 	ReallyLongThing   ReallyLongTableNameToTestMySQLNameLengthLimit | ||||
| } | ||||
| 
 | ||||
| type ReallyLongTableNameToTestMySQLNameLengthLimit struct { | ||||
| 	Id int64 | ||||
| } | ||||
| 
 | ||||
| type ReallyLongThingThatReferencesShort struct { | ||||
| 	Id      int64 | ||||
| 	ShortID int64 | ||||
| 	Short   Short | ||||
| } | ||||
| 
 | ||||
| type Short struct { | ||||
| 	Id int64 | ||||
| } | ||||
| 
 | ||||
| type CreditCard struct { | ||||
| 	ID        int8 | ||||
| 	Number    string | ||||
| 	UserId    sql.NullInt64 | ||||
| 	CreatedAt time.Time `sql:"not null"` | ||||
| 	UpdatedAt time.Time | ||||
| 	DeletedAt *time.Time `sql:"column:deleted_time"` | ||||
| } | ||||
| 
 | ||||
| type Email struct { | ||||
| 	Id        int16 | ||||
| 	UserId    int | ||||
| 	Email     string `sql:"type:varchar(100);"` | ||||
| 	CreatedAt time.Time | ||||
| 	UpdatedAt time.Time | ||||
| } | ||||
| 
 | ||||
| type Address struct { | ||||
| 	ID        int | ||||
| 	Address1  string | ||||
| 	Address2  string | ||||
| 	Post      string | ||||
| 	CreatedAt time.Time | ||||
| 	UpdatedAt time.Time | ||||
| 	DeletedAt *time.Time | ||||
| } | ||||
| 
 | ||||
| type Language struct { | ||||
| 	gorm.Model | ||||
| 	Name  string | ||||
| 	Users []User `gorm:"many2many:user_languages;"` | ||||
| } | ||||
| 
 | ||||
| type Product struct { | ||||
| 	Id                    int64 | ||||
| 	Code                  string | ||||
| 	Price                 int64 | ||||
| 	CreatedAt             time.Time | ||||
| 	UpdatedAt             time.Time | ||||
| 	AfterFindCallTimes    int64 | ||||
| 	BeforeCreateCallTimes int64 | ||||
| 	AfterCreateCallTimes  int64 | ||||
| 	BeforeUpdateCallTimes int64 | ||||
| 	AfterUpdateCallTimes  int64 | ||||
| 	BeforeSaveCallTimes   int64 | ||||
| 	AfterSaveCallTimes    int64 | ||||
| 	BeforeDeleteCallTimes int64 | ||||
| 	AfterDeleteCallTimes  int64 | ||||
| } | ||||
| 
 | ||||
| type Company struct { | ||||
| 	Id    int64 | ||||
| 	Name  string | ||||
| 	Owner *User `sql:"-"` | ||||
| } | ||||
| 
 | ||||
| type Place struct { | ||||
| 	Id             int64 | ||||
| 	PlaceAddressID int | ||||
| 	PlaceAddress   *Address `gorm:"save_associations:false"` | ||||
| 	OwnerAddressID int | ||||
| 	OwnerAddress   *Address `gorm:"save_associations:true"` | ||||
| } | ||||
| 
 | ||||
| type EncryptedData []byte | ||||
| 
 | ||||
| func (data *EncryptedData) Scan(value interface{}) error { | ||||
| 	if b, ok := value.([]byte); ok { | ||||
| 		if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { | ||||
| 			return errors.New("Too short") | ||||
| 		} | ||||
| 
 | ||||
| 		*data = b[3:] | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return errors.New("Bytes expected") | ||||
| } | ||||
| 
 | ||||
| func (data EncryptedData) Value() (driver.Value, error) { | ||||
| 	if len(data) > 0 && data[0] == 'x' { | ||||
| 		//needed to test failures
 | ||||
| 		return nil, errors.New("Should not start with 'x'") | ||||
| 	} | ||||
| 
 | ||||
| 	//prepend asterisks
 | ||||
| 	return append([]byte("***"), data...), nil | ||||
| } | ||||
| 
 | ||||
| type Role struct { | ||||
| 	Name string `gorm:"size:256"` | ||||
| } | ||||
| 
 | ||||
| func (role *Role) Scan(value interface{}) error { | ||||
| 	if b, ok := value.([]uint8); ok { | ||||
| 		role.Name = string(b) | ||||
| 	} else { | ||||
| 		role.Name = value.(string) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (role Role) Value() (driver.Value, error) { | ||||
| 	return role.Name, nil | ||||
| } | ||||
| 
 | ||||
| func (role Role) IsAdmin() bool { | ||||
| 	return role.Name == "admin" | ||||
| } | ||||
| 
 | ||||
| type Num int64 | ||||
| 
 | ||||
| func (i *Num) Scan(src interface{}) error { | ||||
| 	switch s := src.(type) { | ||||
| 	case []byte: | ||||
| 		n, _ := strconv.Atoi(string(s)) | ||||
| 		*i = Num(n) | ||||
| 	case int64: | ||||
| 		*i = Num(s) | ||||
| 	default: | ||||
| 		return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type Animal struct { | ||||
| 	Counter    uint64    `gorm:"primary_key:yes"` | ||||
| 	Name       string    `sql:"DEFAULT:'galeone'"` | ||||
| 	From       string    //test reserved sql keyword as field name
 | ||||
| 	Age        time.Time `sql:"DEFAULT:current_timestamp"` | ||||
| 	unexported string    // unexported value
 | ||||
| 	CreatedAt  time.Time | ||||
| 	UpdatedAt  time.Time | ||||
| } | ||||
| 
 | ||||
| type JoinTable struct { | ||||
| 	From uint64 | ||||
| 	To   uint64 | ||||
| 	Time time.Time `sql:"default: null"` | ||||
| } | ||||
| 
 | ||||
| type Post struct { | ||||
| 	Id             int64 | ||||
| 	CategoryId     sql.NullInt64 | ||||
| 	MainCategoryId int64 | ||||
| 	Title          string | ||||
| 	Body           string | ||||
| 	Comments       []*Comment | ||||
| 	Category       Category | ||||
| 	MainCategory   Category | ||||
| } | ||||
| 
 | ||||
| type Category struct { | ||||
| 	gorm.Model | ||||
| 	Name string | ||||
| 
 | ||||
| 	Categories []Category | ||||
| 	CategoryID *uint | ||||
| } | ||||
| 
 | ||||
| type Comment struct { | ||||
| 	gorm.Model | ||||
| 	PostId  int64 | ||||
| 	Content string | ||||
| 	Post    Post | ||||
| } | ||||
| 
 | ||||
| // Scanner
 | ||||
| type NullValue struct { | ||||
| 	Id      int64 | ||||
| 	Name    sql.NullString  `sql:"not null"` | ||||
| 	Gender  *sql.NullString `sql:"not null"` | ||||
| 	Age     sql.NullInt64 | ||||
| 	Male    sql.NullBool | ||||
| 	Height  sql.NullFloat64 | ||||
| 	AddedAt NullTime | ||||
| } | ||||
| 
 | ||||
| type NullTime struct { | ||||
| 	Time  time.Time | ||||
| 	Valid bool | ||||
| } | ||||
| 
 | ||||
| func (nt *NullTime) Scan(value interface{}) error { | ||||
| 	if value == nil { | ||||
| 		nt.Valid = false | ||||
| 		return nil | ||||
| 	} | ||||
| 	nt.Time, nt.Valid = value.(time.Time), true | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (nt NullTime) Value() (driver.Value, error) { | ||||
| 	if !nt.Valid { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	return nt.Time, nil | ||||
| } | ||||
| 
 | ||||
| func getPreparedUser(name string, role string) *User { | ||||
| 	var company Company | ||||
| 	DB.Where(Company{Name: role}).FirstOrCreate(&company) | ||||
| 
 | ||||
| 	return &User{ | ||||
| 		Name:            name, | ||||
| 		Age:             20, | ||||
| 		Role:            Role{role}, | ||||
| 		BillingAddress:  Address{Address1: fmt.Sprintf("Billing Address %v", name)}, | ||||
| 		ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, | ||||
| 		CreditCard:      CreditCard{Number: fmt.Sprintf("123456%v", name)}, | ||||
| 		Emails: []Email{ | ||||
| 			{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, | ||||
| 		}, | ||||
| 		Company: company, | ||||
| 		Languages: []Language{ | ||||
| 			{Name: fmt.Sprintf("lang_1_%v", name)}, | ||||
| 			{Name: fmt.Sprintf("lang_2_%v", name)}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func runMigration() { | ||||
| 	if err := DB.DropTableIfExists(&User{}).Error; err != nil { | ||||
| 		fmt.Printf("Got error when try to delete table users, %+v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, table := range []string{"animals", "user_languages"} { | ||||
| 		DB.Exec(fmt.Sprintf("drop table %v;", table)) | ||||
| 	} | ||||
| 
 | ||||
| 	values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} | ||||
| 	for _, value := range values { | ||||
| 		DB.DropTable(value) | ||||
| 	} | ||||
| 	if err := DB.AutoMigrate(values...).Error; err != nil { | ||||
| 		panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestIndexes(t *testing.T) { | ||||
| 	if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { | ||||
| 		t.Errorf("Got error when tried to create index: %+v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	scope := DB.NewScope(&Email{}) | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { | ||||
| 		t.Errorf("Email should have index idx_email_email") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { | ||||
| 		t.Errorf("Got error when tried to remove index: %+v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { | ||||
| 		t.Errorf("Email's index idx_email_email should be deleted") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { | ||||
| 		t.Errorf("Got error when tried to create index: %+v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | ||||
| 		t.Errorf("Email should have index idx_email_email_and_user_id") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { | ||||
| 		t.Errorf("Got error when tried to remove index: %+v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | ||||
| 		t.Errorf("Email's index idx_email_email_and_user_id should be deleted") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { | ||||
| 		t.Errorf("Got error when tried to create index: %+v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | ||||
| 		t.Errorf("Email should have index idx_email_email_and_user_id") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { | ||||
| 		t.Errorf("Should get to create duplicate record when having unique index") | ||||
| 	} | ||||
| 
 | ||||
| 	var user = User{Name: "sample_user"} | ||||
| 	DB.Save(&user) | ||||
| 	if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { | ||||
| 		t.Errorf("Should get no error when append two emails for user") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { | ||||
| 		t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { | ||||
| 		t.Errorf("Got error when tried to remove index: %+v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { | ||||
| 		t.Errorf("Email's index idx_email_email_and_user_id should be deleted") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { | ||||
| 		t.Errorf("Should be able to create duplicated emails after remove unique index") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type EmailWithIdx struct { | ||||
| 	Id           int64 | ||||
| 	UserId       int64 | ||||
| 	Email        string     `sql:"index:idx_email_agent"` | ||||
| 	UserAgent    string     `sql:"index:idx_email_agent"` | ||||
| 	RegisteredAt *time.Time `sql:"unique_index"` | ||||
| 	CreatedAt    time.Time | ||||
| 	UpdatedAt    time.Time | ||||
| } | ||||
| 
 | ||||
| func TestAutoMigration(t *testing.T) { | ||||
| 	DB.AutoMigrate(&Address{}) | ||||
| 	DB.DropTable(&EmailWithIdx{}) | ||||
| 	if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { | ||||
| 		t.Errorf("Auto Migrate should not raise any error") | ||||
| 	} | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) | ||||
| 
 | ||||
| 	scope := DB.NewScope(&EmailWithIdx{}) | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { | ||||
| 		t.Errorf("Failed to create index") | ||||
| 	} | ||||
| 
 | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { | ||||
| 		t.Errorf("Failed to create index") | ||||
| 	} | ||||
| 
 | ||||
| 	var bigemail EmailWithIdx | ||||
| 	DB.First(&bigemail, "user_agent = ?", "pc") | ||||
| 	if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { | ||||
| 		t.Error("Big Emails should be saved and fetched correctly") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateAndAutomigrateTransaction(t *testing.T) { | ||||
| 	tx := DB.Begin() | ||||
| 
 | ||||
| 	func() { | ||||
| 		type Bar struct { | ||||
| 			ID uint | ||||
| 		} | ||||
| 		DB.DropTableIfExists(&Bar{}) | ||||
| 
 | ||||
| 		if ok := DB.HasTable("bars"); ok { | ||||
| 			t.Errorf("Table should not exist, but does") | ||||
| 		} | ||||
| 
 | ||||
| 		if ok := tx.HasTable("bars"); ok { | ||||
| 			t.Errorf("Table should not exist, but does") | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	func() { | ||||
| 		type Bar struct { | ||||
| 			Name string | ||||
| 		} | ||||
| 		err := tx.CreateTable(&Bar{}).Error | ||||
| 
 | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Should have been able to create the table, but couldn't: %s", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if ok := tx.HasTable(&Bar{}); !ok { | ||||
| 			t.Errorf("The transaction should be able to see the table") | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	func() { | ||||
| 		type Bar struct { | ||||
| 			Stuff string | ||||
| 		} | ||||
| 
 | ||||
| 		err := tx.AutoMigrate(&Bar{}).Error | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Should have been able to alter the table, but couldn't") | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	tx.Rollback() | ||||
| } | ||||
| 
 | ||||
| type MultipleIndexes struct { | ||||
| 	ID     int64 | ||||
| 	UserID int64  `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` | ||||
| 	Name   string `sql:"unique_index:uix_multipleindexes_user_name"` | ||||
| 	Email  string `sql:"unique_index:,uix_multipleindexes_user_email"` | ||||
| 	Other  string `sql:"index:,idx_multipleindexes_user_other"` | ||||
| } | ||||
| 
 | ||||
| func TestMultipleIndexes(t *testing.T) { | ||||
| 	if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { | ||||
| 		fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.AutoMigrate(&MultipleIndexes{}) | ||||
| 	if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { | ||||
| 		t.Errorf("Auto Migrate should not raise any error") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) | ||||
| 
 | ||||
| 	scope := DB.NewScope(&MultipleIndexes{}) | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { | ||||
| 		t.Errorf("Failed to create index") | ||||
| 	} | ||||
| 
 | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { | ||||
| 		t.Errorf("Failed to create index") | ||||
| 	} | ||||
| 
 | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { | ||||
| 		t.Errorf("Failed to create index") | ||||
| 	} | ||||
| 
 | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { | ||||
| 		t.Errorf("Failed to create index") | ||||
| 	} | ||||
| 
 | ||||
| 	if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { | ||||
| 		t.Errorf("Failed to create index") | ||||
| 	} | ||||
| 
 | ||||
| 	var mutipleIndexes MultipleIndexes | ||||
| 	DB.First(&mutipleIndexes, "name = ?", "jinzhu") | ||||
| 	if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { | ||||
| 		t.Error("MutipleIndexes should be saved and fetched correctly") | ||||
| 	} | ||||
| 
 | ||||
| 	// Check unique constraints
 | ||||
| 	if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { | ||||
| 		t.Error("MultipleIndexes unique index failed") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { | ||||
| 		t.Error("MultipleIndexes unique index failed") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { | ||||
| 		t.Error("MultipleIndexes unique index failed") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { | ||||
| 		t.Error("MultipleIndexes unique index failed") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestModifyColumnType(t *testing.T) { | ||||
| 	if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { | ||||
| 		t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") | ||||
| 	} | ||||
| 
 | ||||
| 	type ModifyColumnType struct { | ||||
| 		gorm.Model | ||||
| 		Name1 string `gorm:"length:100"` | ||||
| 		Name2 string `gorm:"length:200"` | ||||
| 	} | ||||
| 	DB.DropTable(&ModifyColumnType{}) | ||||
| 	DB.CreateTable(&ModifyColumnType{}) | ||||
| 
 | ||||
| 	name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") | ||||
| 	name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) | ||||
| 
 | ||||
| 	if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when ModifyColumn, but got %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestIndexWithPrefixLength(t *testing.T) { | ||||
| 	if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { | ||||
| 		t.Skip("Skipping this because only mysql support setting an index prefix length") | ||||
| 	} | ||||
| 
 | ||||
| 	type IndexWithPrefix struct { | ||||
| 		gorm.Model | ||||
| 		Name        string | ||||
| 		Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` | ||||
| 	} | ||||
| 	type IndexesWithPrefix struct { | ||||
| 		gorm.Model | ||||
| 		Name         string | ||||
| 		Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` | ||||
| 		Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` | ||||
| 	} | ||||
| 	type IndexesWithPrefixAndWithoutPrefix struct { | ||||
| 		gorm.Model | ||||
| 		Name        string `gorm:"index:idx_index_with_prefixes_length"` | ||||
| 		Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` | ||||
| 	} | ||||
| 	tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} | ||||
| 	for _, table := range tables { | ||||
| 		scope := DB.NewScope(table) | ||||
| 		tableName := scope.TableName() | ||||
| 		t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { | ||||
| 			if err := DB.DropTableIfExists(table).Error; err != nil { | ||||
| 				t.Errorf("Failed to drop %s table: %v", tableName, err) | ||||
| 			} | ||||
| 			if err := DB.CreateTable(table).Error; err != nil { | ||||
| 				t.Errorf("Failed to create %s table: %v", tableName, err) | ||||
| 			} | ||||
| 			if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { | ||||
| 				t.Errorf("Failed to create %s table index:", tableName) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										14
									
								
								model.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								model.go
									
									
									
									
									
								
							| @ -1,14 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import "time" | ||||
| 
 | ||||
| // Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
 | ||||
| //    type User struct {
 | ||||
| //      gorm.Model
 | ||||
| //    }
 | ||||
| type Model struct { | ||||
| 	ID        uint `gorm:"primary_key"` | ||||
| 	CreatedAt time.Time | ||||
| 	UpdatedAt time.Time | ||||
| 	DeletedAt *time.Time `sql:"index"` | ||||
| } | ||||
							
								
								
									
										671
									
								
								model_struct.go
									
									
									
									
									
								
							
							
						
						
									
										671
									
								
								model_struct.go
									
									
									
									
									
								
							| @ -1,671 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"go/ast" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/inflection" | ||||
| ) | ||||
| 
 | ||||
| // DefaultTableNameHandler default table name handler
 | ||||
| var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { | ||||
| 	return defaultTableName | ||||
| } | ||||
| 
 | ||||
| // lock for mutating global cached model metadata
 | ||||
| var structsLock sync.Mutex | ||||
| 
 | ||||
| // global cache of model metadata
 | ||||
| var modelStructsMap sync.Map | ||||
| 
 | ||||
| // ModelStruct model definition
 | ||||
| type ModelStruct struct { | ||||
| 	PrimaryFields []*StructField | ||||
| 	StructFields  []*StructField | ||||
| 	ModelType     reflect.Type | ||||
| 
 | ||||
| 	defaultTableName string | ||||
| 	l                sync.Mutex | ||||
| } | ||||
| 
 | ||||
| // TableName returns model's table name
 | ||||
| func (s *ModelStruct) TableName(db *DB) string { | ||||
| 	s.l.Lock() | ||||
| 	defer s.l.Unlock() | ||||
| 
 | ||||
| 	if s.defaultTableName == "" && db != nil && s.ModelType != nil { | ||||
| 		// Set default table name
 | ||||
| 		if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { | ||||
| 			s.defaultTableName = tabler.TableName() | ||||
| 		} else { | ||||
| 			tableName := ToTableName(s.ModelType.Name()) | ||||
| 			db.parent.RLock() | ||||
| 			if db == nil || (db.parent != nil && !db.parent.singularTable) { | ||||
| 				tableName = inflection.Plural(tableName) | ||||
| 			} | ||||
| 			db.parent.RUnlock() | ||||
| 			s.defaultTableName = tableName | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return DefaultTableNameHandler(db, s.defaultTableName) | ||||
| } | ||||
| 
 | ||||
| // StructField model field's struct definition
 | ||||
| type StructField struct { | ||||
| 	DBName          string | ||||
| 	Name            string | ||||
| 	Names           []string | ||||
| 	IsPrimaryKey    bool | ||||
| 	IsNormal        bool | ||||
| 	IsIgnored       bool | ||||
| 	IsScanner       bool | ||||
| 	HasDefaultValue bool | ||||
| 	Tag             reflect.StructTag | ||||
| 	TagSettings     map[string]string | ||||
| 	Struct          reflect.StructField | ||||
| 	IsForeignKey    bool | ||||
| 	Relationship    *Relationship | ||||
| 
 | ||||
| 	tagSettingsLock sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| // TagSettingsSet Sets a tag in the tag settings map
 | ||||
| func (sf *StructField) TagSettingsSet(key, val string) { | ||||
| 	sf.tagSettingsLock.Lock() | ||||
| 	defer sf.tagSettingsLock.Unlock() | ||||
| 	sf.TagSettings[key] = val | ||||
| } | ||||
| 
 | ||||
| // TagSettingsGet returns a tag from the tag settings
 | ||||
| func (sf *StructField) TagSettingsGet(key string) (string, bool) { | ||||
| 	sf.tagSettingsLock.RLock() | ||||
| 	defer sf.tagSettingsLock.RUnlock() | ||||
| 	val, ok := sf.TagSettings[key] | ||||
| 	return val, ok | ||||
| } | ||||
| 
 | ||||
| // TagSettingsDelete deletes a tag
 | ||||
| func (sf *StructField) TagSettingsDelete(key string) { | ||||
| 	sf.tagSettingsLock.Lock() | ||||
| 	defer sf.tagSettingsLock.Unlock() | ||||
| 	delete(sf.TagSettings, key) | ||||
| } | ||||
| 
 | ||||
| func (sf *StructField) clone() *StructField { | ||||
| 	clone := &StructField{ | ||||
| 		DBName:          sf.DBName, | ||||
| 		Name:            sf.Name, | ||||
| 		Names:           sf.Names, | ||||
| 		IsPrimaryKey:    sf.IsPrimaryKey, | ||||
| 		IsNormal:        sf.IsNormal, | ||||
| 		IsIgnored:       sf.IsIgnored, | ||||
| 		IsScanner:       sf.IsScanner, | ||||
| 		HasDefaultValue: sf.HasDefaultValue, | ||||
| 		Tag:             sf.Tag, | ||||
| 		TagSettings:     map[string]string{}, | ||||
| 		Struct:          sf.Struct, | ||||
| 		IsForeignKey:    sf.IsForeignKey, | ||||
| 	} | ||||
| 
 | ||||
| 	if sf.Relationship != nil { | ||||
| 		relationship := *sf.Relationship | ||||
| 		clone.Relationship = &relationship | ||||
| 	} | ||||
| 
 | ||||
| 	// copy the struct field tagSettings, they should be read-locked while they are copied
 | ||||
| 	sf.tagSettingsLock.Lock() | ||||
| 	defer sf.tagSettingsLock.Unlock() | ||||
| 	for key, value := range sf.TagSettings { | ||||
| 		clone.TagSettings[key] = value | ||||
| 	} | ||||
| 
 | ||||
| 	return clone | ||||
| } | ||||
| 
 | ||||
| // Relationship described the relationship between models
 | ||||
| type Relationship struct { | ||||
| 	Kind                         string | ||||
| 	PolymorphicType              string | ||||
| 	PolymorphicDBName            string | ||||
| 	PolymorphicValue             string | ||||
| 	ForeignFieldNames            []string | ||||
| 	ForeignDBNames               []string | ||||
| 	AssociationForeignFieldNames []string | ||||
| 	AssociationForeignDBNames    []string | ||||
| 	JoinTableHandler             JoinTableHandlerInterface | ||||
| } | ||||
| 
 | ||||
| func getForeignField(column string, fields []*StructField) *StructField { | ||||
| 	for _, field := range fields { | ||||
| 		if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { | ||||
| 			return field | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // GetModelStruct get value's model struct, relationships based on struct and tag definition
 | ||||
| func (scope *Scope) GetModelStruct() *ModelStruct { | ||||
| 	var modelStruct ModelStruct | ||||
| 	// Scope value can't be nil
 | ||||
| 	if scope.Value == nil { | ||||
| 		return &modelStruct | ||||
| 	} | ||||
| 
 | ||||
| 	reflectType := reflect.ValueOf(scope.Value).Type() | ||||
| 	for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { | ||||
| 		reflectType = reflectType.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	// Scope value need to be a struct
 | ||||
| 	if reflectType.Kind() != reflect.Struct { | ||||
| 		return &modelStruct | ||||
| 	} | ||||
| 
 | ||||
| 	// Get Cached model struct
 | ||||
| 	isSingularTable := false | ||||
| 	if scope.db != nil && scope.db.parent != nil { | ||||
| 		scope.db.parent.RLock() | ||||
| 		isSingularTable = scope.db.parent.singularTable | ||||
| 		scope.db.parent.RUnlock() | ||||
| 	} | ||||
| 
 | ||||
| 	hashKey := struct { | ||||
| 		singularTable bool | ||||
| 		reflectType   reflect.Type | ||||
| 	}{isSingularTable, reflectType} | ||||
| 	if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { | ||||
| 		return value.(*ModelStruct) | ||||
| 	} | ||||
| 
 | ||||
| 	modelStruct.ModelType = reflectType | ||||
| 
 | ||||
| 	// Get all fields
 | ||||
| 	for i := 0; i < reflectType.NumField(); i++ { | ||||
| 		if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { | ||||
| 			field := &StructField{ | ||||
| 				Struct:      fieldStruct, | ||||
| 				Name:        fieldStruct.Name, | ||||
| 				Names:       []string{fieldStruct.Name}, | ||||
| 				Tag:         fieldStruct.Tag, | ||||
| 				TagSettings: parseTagSetting(fieldStruct.Tag), | ||||
| 			} | ||||
| 
 | ||||
| 			// is ignored field
 | ||||
| 			if _, ok := field.TagSettingsGet("-"); ok { | ||||
| 				field.IsIgnored = true | ||||
| 			} else { | ||||
| 				if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { | ||||
| 					field.IsPrimaryKey = true | ||||
| 					modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) | ||||
| 				} | ||||
| 
 | ||||
| 				if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { | ||||
| 					field.HasDefaultValue = true | ||||
| 				} | ||||
| 
 | ||||
| 				if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { | ||||
| 					field.HasDefaultValue = true | ||||
| 				} | ||||
| 
 | ||||
| 				indirectType := fieldStruct.Type | ||||
| 				for indirectType.Kind() == reflect.Ptr { | ||||
| 					indirectType = indirectType.Elem() | ||||
| 				} | ||||
| 
 | ||||
| 				fieldValue := reflect.New(indirectType).Interface() | ||||
| 				if _, isScanner := fieldValue.(sql.Scanner); isScanner { | ||||
| 					// is scanner
 | ||||
| 					field.IsScanner, field.IsNormal = true, true | ||||
| 					if indirectType.Kind() == reflect.Struct { | ||||
| 						for i := 0; i < indirectType.NumField(); i++ { | ||||
| 							for key, value := range parseTagSetting(indirectType.Field(i).Tag) { | ||||
| 								if _, ok := field.TagSettingsGet(key); !ok { | ||||
| 									field.TagSettingsSet(key, value) | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} else if _, isTime := fieldValue.(*time.Time); isTime { | ||||
| 					// is time
 | ||||
| 					field.IsNormal = true | ||||
| 				} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { | ||||
| 					// is embedded struct
 | ||||
| 					for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { | ||||
| 						subField = subField.clone() | ||||
| 						subField.Names = append([]string{fieldStruct.Name}, subField.Names...) | ||||
| 						if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { | ||||
| 							subField.DBName = prefix + subField.DBName | ||||
| 						} | ||||
| 
 | ||||
| 						if subField.IsPrimaryKey { | ||||
| 							if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { | ||||
| 								modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) | ||||
| 							} else { | ||||
| 								subField.IsPrimaryKey = false | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { | ||||
| 							if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { | ||||
| 								newJoinTableHandler := &JoinTableHandler{} | ||||
| 								newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) | ||||
| 								subField.Relationship.JoinTableHandler = newJoinTableHandler | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						modelStruct.StructFields = append(modelStruct.StructFields, subField) | ||||
| 					} | ||||
| 					continue | ||||
| 				} else { | ||||
| 					// build relationships
 | ||||
| 					switch indirectType.Kind() { | ||||
| 					case reflect.Slice: | ||||
| 						defer func(field *StructField) { | ||||
| 							var ( | ||||
| 								relationship           = &Relationship{} | ||||
| 								toScope                = scope.New(reflect.New(field.Struct.Type).Interface()) | ||||
| 								foreignKeys            []string | ||||
| 								associationForeignKeys []string | ||||
| 								elemType               = field.Struct.Type | ||||
| 							) | ||||
| 
 | ||||
| 							if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { | ||||
| 								foreignKeys = strings.Split(foreignKey, ",") | ||||
| 							} | ||||
| 
 | ||||
| 							if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { | ||||
| 								associationForeignKeys = strings.Split(foreignKey, ",") | ||||
| 							} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { | ||||
| 								associationForeignKeys = strings.Split(foreignKey, ",") | ||||
| 							} | ||||
| 
 | ||||
| 							for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { | ||||
| 								elemType = elemType.Elem() | ||||
| 							} | ||||
| 
 | ||||
| 							if elemType.Kind() == reflect.Struct { | ||||
| 								if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { | ||||
| 									relationship.Kind = "many_to_many" | ||||
| 
 | ||||
| 									{ // Foreign Keys for Source
 | ||||
| 										joinTableDBNames := []string{} | ||||
| 
 | ||||
| 										if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { | ||||
| 											joinTableDBNames = strings.Split(foreignKey, ",") | ||||
| 										} | ||||
| 
 | ||||
| 										// if no foreign keys defined with tag
 | ||||
| 										if len(foreignKeys) == 0 { | ||||
| 											for _, field := range modelStruct.PrimaryFields { | ||||
| 												foreignKeys = append(foreignKeys, field.DBName) | ||||
| 											} | ||||
| 										} | ||||
| 
 | ||||
| 										for idx, foreignKey := range foreignKeys { | ||||
| 											if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { | ||||
| 												// source foreign keys (db names)
 | ||||
| 												relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) | ||||
| 
 | ||||
| 												// setup join table foreign keys for source
 | ||||
| 												if len(joinTableDBNames) > idx { | ||||
| 													// if defined join table's foreign key
 | ||||
| 													relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) | ||||
| 												} else { | ||||
| 													defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName | ||||
| 													relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) | ||||
| 												} | ||||
| 											} | ||||
| 										} | ||||
| 									} | ||||
| 
 | ||||
| 									{ // Foreign Keys for Association (Destination)
 | ||||
| 										associationJoinTableDBNames := []string{} | ||||
| 
 | ||||
| 										if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { | ||||
| 											associationJoinTableDBNames = strings.Split(foreignKey, ",") | ||||
| 										} | ||||
| 
 | ||||
| 										// if no association foreign keys defined with tag
 | ||||
| 										if len(associationForeignKeys) == 0 { | ||||
| 											for _, field := range toScope.PrimaryFields() { | ||||
| 												associationForeignKeys = append(associationForeignKeys, field.DBName) | ||||
| 											} | ||||
| 										} | ||||
| 
 | ||||
| 										for idx, name := range associationForeignKeys { | ||||
| 											if field, ok := toScope.FieldByName(name); ok { | ||||
| 												// association foreign keys (db names)
 | ||||
| 												relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) | ||||
| 
 | ||||
| 												// setup join table foreign keys for association
 | ||||
| 												if len(associationJoinTableDBNames) > idx { | ||||
| 													relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) | ||||
| 												} else { | ||||
| 													// join table foreign keys for association
 | ||||
| 													joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName | ||||
| 													relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) | ||||
| 												} | ||||
| 											} | ||||
| 										} | ||||
| 									} | ||||
| 
 | ||||
| 									joinTableHandler := JoinTableHandler{} | ||||
| 									joinTableHandler.Setup(relationship, many2many, reflectType, elemType) | ||||
| 									relationship.JoinTableHandler = &joinTableHandler | ||||
| 									field.Relationship = relationship | ||||
| 								} else { | ||||
| 									// User has many comments, associationType is User, comment use UserID as foreign key
 | ||||
| 									var associationType = reflectType.Name() | ||||
| 									var toFields = toScope.GetStructFields() | ||||
| 									relationship.Kind = "has_many" | ||||
| 
 | ||||
| 									if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { | ||||
| 										// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
 | ||||
| 										// Toy use OwnerID, OwnerType ('dogs') as foreign key
 | ||||
| 										if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { | ||||
| 											associationType = polymorphic | ||||
| 											relationship.PolymorphicType = polymorphicType.Name | ||||
| 											relationship.PolymorphicDBName = polymorphicType.DBName | ||||
| 											// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
 | ||||
| 											if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { | ||||
| 												relationship.PolymorphicValue = value | ||||
| 											} else { | ||||
| 												relationship.PolymorphicValue = scope.TableName() | ||||
| 											} | ||||
| 											polymorphicType.IsForeignKey = true | ||||
| 										} | ||||
| 									} | ||||
| 
 | ||||
| 									// if no foreign keys defined with tag
 | ||||
| 									if len(foreignKeys) == 0 { | ||||
| 										// if no association foreign keys defined with tag
 | ||||
| 										if len(associationForeignKeys) == 0 { | ||||
| 											for _, field := range modelStruct.PrimaryFields { | ||||
| 												foreignKeys = append(foreignKeys, associationType+field.Name) | ||||
| 												associationForeignKeys = append(associationForeignKeys, field.Name) | ||||
| 											} | ||||
| 										} else { | ||||
| 											// generate foreign keys from defined association foreign keys
 | ||||
| 											for _, scopeFieldName := range associationForeignKeys { | ||||
| 												if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { | ||||
| 													foreignKeys = append(foreignKeys, associationType+foreignField.Name) | ||||
| 													associationForeignKeys = append(associationForeignKeys, foreignField.Name) | ||||
| 												} | ||||
| 											} | ||||
| 										} | ||||
| 									} else { | ||||
| 										// generate association foreign keys from foreign keys
 | ||||
| 										if len(associationForeignKeys) == 0 { | ||||
| 											for _, foreignKey := range foreignKeys { | ||||
| 												if strings.HasPrefix(foreignKey, associationType) { | ||||
| 													associationForeignKey := strings.TrimPrefix(foreignKey, associationType) | ||||
| 													if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { | ||||
| 														associationForeignKeys = append(associationForeignKeys, associationForeignKey) | ||||
| 													} | ||||
| 												} | ||||
| 											} | ||||
| 											if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { | ||||
| 												associationForeignKeys = []string{scope.PrimaryKey()} | ||||
| 											} | ||||
| 										} else if len(foreignKeys) != len(associationForeignKeys) { | ||||
| 											scope.Err(errors.New("invalid foreign keys, should have same length")) | ||||
| 											return | ||||
| 										} | ||||
| 									} | ||||
| 
 | ||||
| 									for idx, foreignKey := range foreignKeys { | ||||
| 										if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { | ||||
| 											if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { | ||||
| 												// mark field as foreignkey, use global lock to avoid race
 | ||||
| 												structsLock.Lock() | ||||
| 												foreignField.IsForeignKey = true | ||||
| 												structsLock.Unlock() | ||||
| 
 | ||||
| 												// association foreign keys
 | ||||
| 												relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) | ||||
| 												relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) | ||||
| 
 | ||||
| 												// association foreign keys
 | ||||
| 												relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | ||||
| 												relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | ||||
| 											} | ||||
| 										} | ||||
| 									} | ||||
| 
 | ||||
| 									if len(relationship.ForeignFieldNames) != 0 { | ||||
| 										field.Relationship = relationship | ||||
| 									} | ||||
| 								} | ||||
| 							} else { | ||||
| 								field.IsNormal = true | ||||
| 							} | ||||
| 						}(field) | ||||
| 					case reflect.Struct: | ||||
| 						defer func(field *StructField) { | ||||
| 							var ( | ||||
| 								// user has one profile, associationType is User, profile use UserID as foreign key
 | ||||
| 								// user belongs to profile, associationType is Profile, user use ProfileID as foreign key
 | ||||
| 								associationType           = reflectType.Name() | ||||
| 								relationship              = &Relationship{} | ||||
| 								toScope                   = scope.New(reflect.New(field.Struct.Type).Interface()) | ||||
| 								toFields                  = toScope.GetStructFields() | ||||
| 								tagForeignKeys            []string | ||||
| 								tagAssociationForeignKeys []string | ||||
| 							) | ||||
| 
 | ||||
| 							if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { | ||||
| 								tagForeignKeys = strings.Split(foreignKey, ",") | ||||
| 							} | ||||
| 
 | ||||
| 							if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { | ||||
| 								tagAssociationForeignKeys = strings.Split(foreignKey, ",") | ||||
| 							} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { | ||||
| 								tagAssociationForeignKeys = strings.Split(foreignKey, ",") | ||||
| 							} | ||||
| 
 | ||||
| 							if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { | ||||
| 								// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
 | ||||
| 								// Toy use OwnerID, OwnerType ('cats') as foreign key
 | ||||
| 								if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { | ||||
| 									associationType = polymorphic | ||||
| 									relationship.PolymorphicType = polymorphicType.Name | ||||
| 									relationship.PolymorphicDBName = polymorphicType.DBName | ||||
| 									// if Cat has several different types of toys set name for each (instead of default 'cats')
 | ||||
| 									if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { | ||||
| 										relationship.PolymorphicValue = value | ||||
| 									} else { | ||||
| 										relationship.PolymorphicValue = scope.TableName() | ||||
| 									} | ||||
| 									polymorphicType.IsForeignKey = true | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							// Has One
 | ||||
| 							{ | ||||
| 								var foreignKeys = tagForeignKeys | ||||
| 								var associationForeignKeys = tagAssociationForeignKeys | ||||
| 								// if no foreign keys defined with tag
 | ||||
| 								if len(foreignKeys) == 0 { | ||||
| 									// if no association foreign keys defined with tag
 | ||||
| 									if len(associationForeignKeys) == 0 { | ||||
| 										for _, primaryField := range modelStruct.PrimaryFields { | ||||
| 											foreignKeys = append(foreignKeys, associationType+primaryField.Name) | ||||
| 											associationForeignKeys = append(associationForeignKeys, primaryField.Name) | ||||
| 										} | ||||
| 									} else { | ||||
| 										// generate foreign keys form association foreign keys
 | ||||
| 										for _, associationForeignKey := range tagAssociationForeignKeys { | ||||
| 											if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { | ||||
| 												foreignKeys = append(foreignKeys, associationType+foreignField.Name) | ||||
| 												associationForeignKeys = append(associationForeignKeys, foreignField.Name) | ||||
| 											} | ||||
| 										} | ||||
| 									} | ||||
| 								} else { | ||||
| 									// generate association foreign keys from foreign keys
 | ||||
| 									if len(associationForeignKeys) == 0 { | ||||
| 										for _, foreignKey := range foreignKeys { | ||||
| 											if strings.HasPrefix(foreignKey, associationType) { | ||||
| 												associationForeignKey := strings.TrimPrefix(foreignKey, associationType) | ||||
| 												if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { | ||||
| 													associationForeignKeys = append(associationForeignKeys, associationForeignKey) | ||||
| 												} | ||||
| 											} | ||||
| 										} | ||||
| 										if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { | ||||
| 											associationForeignKeys = []string{scope.PrimaryKey()} | ||||
| 										} | ||||
| 									} else if len(foreignKeys) != len(associationForeignKeys) { | ||||
| 										scope.Err(errors.New("invalid foreign keys, should have same length")) | ||||
| 										return | ||||
| 									} | ||||
| 								} | ||||
| 
 | ||||
| 								for idx, foreignKey := range foreignKeys { | ||||
| 									if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { | ||||
| 										if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { | ||||
| 											// mark field as foreignkey, use global lock to avoid race
 | ||||
| 											structsLock.Lock() | ||||
| 											foreignField.IsForeignKey = true | ||||
| 											structsLock.Unlock() | ||||
| 
 | ||||
| 											// association foreign keys
 | ||||
| 											relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) | ||||
| 											relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) | ||||
| 
 | ||||
| 											// association foreign keys
 | ||||
| 											relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | ||||
| 											relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | ||||
| 										} | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							if len(relationship.ForeignFieldNames) != 0 { | ||||
| 								relationship.Kind = "has_one" | ||||
| 								field.Relationship = relationship | ||||
| 							} else { | ||||
| 								var foreignKeys = tagForeignKeys | ||||
| 								var associationForeignKeys = tagAssociationForeignKeys | ||||
| 
 | ||||
| 								if len(foreignKeys) == 0 { | ||||
| 									// generate foreign keys & association foreign keys
 | ||||
| 									if len(associationForeignKeys) == 0 { | ||||
| 										for _, primaryField := range toScope.PrimaryFields() { | ||||
| 											foreignKeys = append(foreignKeys, field.Name+primaryField.Name) | ||||
| 											associationForeignKeys = append(associationForeignKeys, primaryField.Name) | ||||
| 										} | ||||
| 									} else { | ||||
| 										// generate foreign keys with association foreign keys
 | ||||
| 										for _, associationForeignKey := range associationForeignKeys { | ||||
| 											if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { | ||||
| 												foreignKeys = append(foreignKeys, field.Name+foreignField.Name) | ||||
| 												associationForeignKeys = append(associationForeignKeys, foreignField.Name) | ||||
| 											} | ||||
| 										} | ||||
| 									} | ||||
| 								} else { | ||||
| 									// generate foreign keys & association foreign keys
 | ||||
| 									if len(associationForeignKeys) == 0 { | ||||
| 										for _, foreignKey := range foreignKeys { | ||||
| 											if strings.HasPrefix(foreignKey, field.Name) { | ||||
| 												associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) | ||||
| 												if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { | ||||
| 													associationForeignKeys = append(associationForeignKeys, associationForeignKey) | ||||
| 												} | ||||
| 											} | ||||
| 										} | ||||
| 										if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { | ||||
| 											associationForeignKeys = []string{toScope.PrimaryKey()} | ||||
| 										} | ||||
| 									} else if len(foreignKeys) != len(associationForeignKeys) { | ||||
| 										scope.Err(errors.New("invalid foreign keys, should have same length")) | ||||
| 										return | ||||
| 									} | ||||
| 								} | ||||
| 
 | ||||
| 								for idx, foreignKey := range foreignKeys { | ||||
| 									if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { | ||||
| 										if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { | ||||
| 											// mark field as foreignkey, use global lock to avoid race
 | ||||
| 											structsLock.Lock() | ||||
| 											foreignField.IsForeignKey = true | ||||
| 											structsLock.Unlock() | ||||
| 
 | ||||
| 											// association foreign keys
 | ||||
| 											relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) | ||||
| 											relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) | ||||
| 
 | ||||
| 											// source foreign keys
 | ||||
| 											relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) | ||||
| 											relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) | ||||
| 										} | ||||
| 									} | ||||
| 								} | ||||
| 
 | ||||
| 								if len(relationship.ForeignFieldNames) != 0 { | ||||
| 									relationship.Kind = "belongs_to" | ||||
| 									field.Relationship = relationship | ||||
| 								} | ||||
| 							} | ||||
| 						}(field) | ||||
| 					default: | ||||
| 						field.IsNormal = true | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// Even it is ignored, also possible to decode db value into the field
 | ||||
| 			if value, ok := field.TagSettingsGet("COLUMN"); ok { | ||||
| 				field.DBName = value | ||||
| 			} else { | ||||
| 				field.DBName = ToColumnName(fieldStruct.Name) | ||||
| 			} | ||||
| 
 | ||||
| 			modelStruct.StructFields = append(modelStruct.StructFields, field) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if len(modelStruct.PrimaryFields) == 0 { | ||||
| 		if field := getForeignField("id", modelStruct.StructFields); field != nil { | ||||
| 			field.IsPrimaryKey = true | ||||
| 			modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	modelStructsMap.Store(hashKey, &modelStruct) | ||||
| 
 | ||||
| 	return &modelStruct | ||||
| } | ||||
| 
 | ||||
| // GetStructFields get model's field structs
 | ||||
| func (scope *Scope) GetStructFields() (fields []*StructField) { | ||||
| 	return scope.GetModelStruct().StructFields | ||||
| } | ||||
| 
 | ||||
| func parseTagSetting(tags reflect.StructTag) map[string]string { | ||||
| 	setting := map[string]string{} | ||||
| 	for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { | ||||
| 		if str == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		tags := strings.Split(str, ";") | ||||
| 		for _, value := range tags { | ||||
| 			v := strings.Split(value, ":") | ||||
| 			k := strings.TrimSpace(strings.ToUpper(v[0])) | ||||
| 			if len(v) >= 2 { | ||||
| 				setting[k] = strings.Join(v[1:], ":") | ||||
| 			} else { | ||||
| 				setting[k] = k | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return setting | ||||
| } | ||||
| @ -1,93 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| type ModelA struct { | ||||
| 	gorm.Model | ||||
| 	Name string | ||||
| 
 | ||||
| 	ModelCs []ModelC `gorm:"foreignkey:OtherAID"` | ||||
| } | ||||
| 
 | ||||
| type ModelB struct { | ||||
| 	gorm.Model | ||||
| 	Name string | ||||
| 
 | ||||
| 	ModelCs []ModelC `gorm:"foreignkey:OtherBID"` | ||||
| } | ||||
| 
 | ||||
| type ModelC struct { | ||||
| 	gorm.Model | ||||
| 	Name string | ||||
| 
 | ||||
| 	OtherAID uint64 | ||||
| 	OtherA   *ModelA `gorm:"foreignkey:OtherAID"` | ||||
| 	OtherBID uint64 | ||||
| 	OtherB   *ModelB `gorm:"foreignkey:OtherBID"` | ||||
| } | ||||
| 
 | ||||
| // This test will try to cause a race condition on the model's foreignkey metadata
 | ||||
| func TestModelStructRaceSameModel(t *testing.T) { | ||||
| 	// use a WaitGroup to execute as much in-sync as possible
 | ||||
| 	// it's more likely to hit a race condition than without
 | ||||
| 	n := 32 | ||||
| 	start := sync.WaitGroup{} | ||||
| 	start.Add(n) | ||||
| 
 | ||||
| 	// use another WaitGroup to know when the test is done
 | ||||
| 	done := sync.WaitGroup{} | ||||
| 	done.Add(n) | ||||
| 
 | ||||
| 	for i := 0; i < n; i++ { | ||||
| 		go func() { | ||||
| 			start.Wait() | ||||
| 
 | ||||
| 			// call GetStructFields, this had a race condition before we fixed it
 | ||||
| 			DB.NewScope(&ModelA{}).GetStructFields() | ||||
| 
 | ||||
| 			done.Done() | ||||
| 		}() | ||||
| 
 | ||||
| 		start.Done() | ||||
| 	} | ||||
| 
 | ||||
| 	done.Wait() | ||||
| } | ||||
| 
 | ||||
| // This test will try to cause a race condition on the model's foreignkey metadata
 | ||||
| func TestModelStructRaceDifferentModel(t *testing.T) { | ||||
| 	// use a WaitGroup to execute as much in-sync as possible
 | ||||
| 	// it's more likely to hit a race condition than without
 | ||||
| 	n := 32 | ||||
| 	start := sync.WaitGroup{} | ||||
| 	start.Add(n) | ||||
| 
 | ||||
| 	// use another WaitGroup to know when the test is done
 | ||||
| 	done := sync.WaitGroup{} | ||||
| 	done.Add(n) | ||||
| 
 | ||||
| 	for i := 0; i < n; i++ { | ||||
| 		i := i | ||||
| 		go func() { | ||||
| 			start.Wait() | ||||
| 
 | ||||
| 			// call GetStructFields, this had a race condition before we fixed it
 | ||||
| 			if i%2 == 0 { | ||||
| 				DB.NewScope(&ModelA{}).GetStructFields() | ||||
| 			} else { | ||||
| 				DB.NewScope(&ModelB{}).GetStructFields() | ||||
| 			} | ||||
| 
 | ||||
| 			done.Done() | ||||
| 		}() | ||||
| 
 | ||||
| 		start.Done() | ||||
| 	} | ||||
| 
 | ||||
| 	done.Wait() | ||||
| } | ||||
| @ -1,381 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| type Blog struct { | ||||
| 	ID         uint   `gorm:"primary_key"` | ||||
| 	Locale     string `gorm:"primary_key"` | ||||
| 	Subject    string | ||||
| 	Body       string | ||||
| 	Tags       []Tag `gorm:"many2many:blog_tags;"` | ||||
| 	SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` | ||||
| 	LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` | ||||
| } | ||||
| 
 | ||||
| type Tag struct { | ||||
| 	ID     uint   `gorm:"primary_key"` | ||||
| 	Locale string `gorm:"primary_key"` | ||||
| 	Value  string | ||||
| 	Blogs  []*Blog `gorm:"many2many:blogs_tags"` | ||||
| } | ||||
| 
 | ||||
| func compareTags(tags []Tag, contents []string) bool { | ||||
| 	var tagContents []string | ||||
| 	for _, tag := range tags { | ||||
| 		tagContents = append(tagContents, tag.Value) | ||||
| 	} | ||||
| 	sort.Strings(tagContents) | ||||
| 	sort.Strings(contents) | ||||
| 	return reflect.DeepEqual(tagContents, contents) | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { | ||||
| 	if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { | ||||
| 		DB.DropTable(&Blog{}, &Tag{}) | ||||
| 		DB.DropTable("blog_tags") | ||||
| 		DB.CreateTable(&Blog{}, &Tag{}) | ||||
| 		blog := Blog{ | ||||
| 			Locale:  "ZH", | ||||
| 			Subject: "subject", | ||||
| 			Body:    "body", | ||||
| 			Tags: []Tag{ | ||||
| 				{Locale: "ZH", Value: "tag1"}, | ||||
| 				{Locale: "ZH", Value: "tag2"}, | ||||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Save(&blog) | ||||
| 		if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { | ||||
| 			t.Errorf("Blog should has two tags") | ||||
| 		} | ||||
| 
 | ||||
| 		// Append
 | ||||
| 		var tag3 = &Tag{Locale: "ZH", Value: "tag3"} | ||||
| 		DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) | ||||
| 		if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Blog should has three tags after Append") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("Tags").Count() != 3 { | ||||
| 			t.Errorf("Blog should has three tags after Append") | ||||
| 		} | ||||
| 
 | ||||
| 		var tags []Tag | ||||
| 		DB.Model(&blog).Related(&tags, "Tags") | ||||
| 		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Should find 3 tags with Related") | ||||
| 		} | ||||
| 
 | ||||
| 		var blog1 Blog | ||||
| 		DB.Preload("Tags").Find(&blog1) | ||||
| 		if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Preload many2many relations") | ||||
| 		} | ||||
| 
 | ||||
| 		// Replace
 | ||||
| 		var tag5 = &Tag{Locale: "ZH", Value: "tag5"} | ||||
| 		var tag6 = &Tag{Locale: "ZH", Value: "tag6"} | ||||
| 		DB.Model(&blog).Association("Tags").Replace(tag5, tag6) | ||||
| 		var tags2 []Tag | ||||
| 		DB.Model(&blog).Related(&tags2, "Tags") | ||||
| 		if !compareTags(tags2, []string{"tag5", "tag6"}) { | ||||
| 			t.Errorf("Should find 2 tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("Tags").Count() != 2 { | ||||
| 			t.Errorf("Blog should has three tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		// Delete
 | ||||
| 		DB.Model(&blog).Association("Tags").Delete(tag5) | ||||
| 		var tags3 []Tag | ||||
| 		DB.Model(&blog).Related(&tags3, "Tags") | ||||
| 		if !compareTags(tags3, []string{"tag6"}) { | ||||
| 			t.Errorf("Should find 1 tags after Delete") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("Tags").Count() != 1 { | ||||
| 			t.Errorf("Blog should has three tags after Delete") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog).Association("Tags").Delete(tag3) | ||||
| 		var tags4 []Tag | ||||
| 		DB.Model(&blog).Related(&tags4, "Tags") | ||||
| 		if !compareTags(tags4, []string{"tag6"}) { | ||||
| 			t.Errorf("Tag should not be deleted when Delete with a unrelated tag") | ||||
| 		} | ||||
| 
 | ||||
| 		// Clear
 | ||||
| 		DB.Model(&blog).Association("Tags").Clear() | ||||
| 		if DB.Model(&blog).Association("Tags").Count() != 0 { | ||||
| 			t.Errorf("All tags should be cleared") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { | ||||
| 	if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { | ||||
| 		DB.DropTable(&Blog{}, &Tag{}) | ||||
| 		DB.DropTable("shared_blog_tags") | ||||
| 		DB.CreateTable(&Blog{}, &Tag{}) | ||||
| 		blog := Blog{ | ||||
| 			Locale:  "ZH", | ||||
| 			Subject: "subject", | ||||
| 			Body:    "body", | ||||
| 			SharedTags: []Tag{ | ||||
| 				{Locale: "ZH", Value: "tag1"}, | ||||
| 				{Locale: "ZH", Value: "tag2"}, | ||||
| 			}, | ||||
| 		} | ||||
| 		DB.Save(&blog) | ||||
| 
 | ||||
| 		blog2 := Blog{ | ||||
| 			ID:     blog.ID, | ||||
| 			Locale: "EN", | ||||
| 		} | ||||
| 		DB.Create(&blog2) | ||||
| 
 | ||||
| 		if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { | ||||
| 			t.Errorf("Blog should has two tags") | ||||
| 		} | ||||
| 
 | ||||
| 		// Append
 | ||||
| 		var tag3 = &Tag{Locale: "ZH", Value: "tag3"} | ||||
| 		DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) | ||||
| 		if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Blog should has three tags after Append") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("SharedTags").Count() != 3 { | ||||
| 			t.Errorf("Blog should has three tags after Append") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog2).Association("SharedTags").Count() != 3 { | ||||
| 			t.Errorf("Blog should has three tags after Append") | ||||
| 		} | ||||
| 
 | ||||
| 		var tags []Tag | ||||
| 		DB.Model(&blog).Related(&tags, "SharedTags") | ||||
| 		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Should find 3 tags with Related") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Related(&tags, "SharedTags") | ||||
| 		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Should find 3 tags with Related") | ||||
| 		} | ||||
| 
 | ||||
| 		var blog1 Blog | ||||
| 		DB.Preload("SharedTags").Find(&blog1) | ||||
| 		if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Preload many2many relations") | ||||
| 		} | ||||
| 
 | ||||
| 		var tag4 = &Tag{Locale: "ZH", Value: "tag4"} | ||||
| 		DB.Model(&blog2).Association("SharedTags").Append(tag4) | ||||
| 
 | ||||
| 		DB.Model(&blog).Related(&tags, "SharedTags") | ||||
| 		if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { | ||||
| 			t.Errorf("Should find 3 tags with Related") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Related(&tags, "SharedTags") | ||||
| 		if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { | ||||
| 			t.Errorf("Should find 3 tags with Related") | ||||
| 		} | ||||
| 
 | ||||
| 		// Replace
 | ||||
| 		var tag5 = &Tag{Locale: "ZH", Value: "tag5"} | ||||
| 		var tag6 = &Tag{Locale: "ZH", Value: "tag6"} | ||||
| 		DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) | ||||
| 		var tags2 []Tag | ||||
| 		DB.Model(&blog).Related(&tags2, "SharedTags") | ||||
| 		if !compareTags(tags2, []string{"tag5", "tag6"}) { | ||||
| 			t.Errorf("Should find 2 tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Related(&tags2, "SharedTags") | ||||
| 		if !compareTags(tags2, []string{"tag5", "tag6"}) { | ||||
| 			t.Errorf("Should find 2 tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("SharedTags").Count() != 2 { | ||||
| 			t.Errorf("Blog should has three tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		// Delete
 | ||||
| 		DB.Model(&blog).Association("SharedTags").Delete(tag5) | ||||
| 		var tags3 []Tag | ||||
| 		DB.Model(&blog).Related(&tags3, "SharedTags") | ||||
| 		if !compareTags(tags3, []string{"tag6"}) { | ||||
| 			t.Errorf("Should find 1 tags after Delete") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("SharedTags").Count() != 1 { | ||||
| 			t.Errorf("Blog should has three tags after Delete") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Association("SharedTags").Delete(tag3) | ||||
| 		var tags4 []Tag | ||||
| 		DB.Model(&blog).Related(&tags4, "SharedTags") | ||||
| 		if !compareTags(tags4, []string{"tag6"}) { | ||||
| 			t.Errorf("Tag should not be deleted when Delete with a unrelated tag") | ||||
| 		} | ||||
| 
 | ||||
| 		// Clear
 | ||||
| 		DB.Model(&blog2).Association("SharedTags").Clear() | ||||
| 		if DB.Model(&blog).Association("SharedTags").Count() != 0 { | ||||
| 			t.Errorf("All tags should be cleared") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { | ||||
| 	if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { | ||||
| 		DB.DropTable(&Blog{}, &Tag{}) | ||||
| 		DB.DropTable("locale_blog_tags") | ||||
| 		DB.CreateTable(&Blog{}, &Tag{}) | ||||
| 		blog := Blog{ | ||||
| 			Locale:  "ZH", | ||||
| 			Subject: "subject", | ||||
| 			Body:    "body", | ||||
| 			LocaleTags: []Tag{ | ||||
| 				{Locale: "ZH", Value: "tag1"}, | ||||
| 				{Locale: "ZH", Value: "tag2"}, | ||||
| 			}, | ||||
| 		} | ||||
| 		DB.Save(&blog) | ||||
| 
 | ||||
| 		blog2 := Blog{ | ||||
| 			ID:     blog.ID, | ||||
| 			Locale: "EN", | ||||
| 		} | ||||
| 		DB.Create(&blog2) | ||||
| 
 | ||||
| 		// Append
 | ||||
| 		var tag3 = &Tag{Locale: "ZH", Value: "tag3"} | ||||
| 		DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) | ||||
| 		if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Blog should has three tags after Append") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | ||||
| 			t.Errorf("Blog should has three tags after Append") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { | ||||
| 			t.Errorf("EN Blog should has 0 tags after ZH Blog Append") | ||||
| 		} | ||||
| 
 | ||||
| 		var tags []Tag | ||||
| 		DB.Model(&blog).Related(&tags, "LocaleTags") | ||||
| 		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Should find 3 tags with Related") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Related(&tags, "LocaleTags") | ||||
| 		if len(tags) != 0 { | ||||
| 			t.Errorf("Should find 0 tags with Related for EN Blog") | ||||
| 		} | ||||
| 
 | ||||
| 		var blog1 Blog | ||||
| 		DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) | ||||
| 		if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Preload many2many relations") | ||||
| 		} | ||||
| 
 | ||||
| 		var tag4 = &Tag{Locale: "ZH", Value: "tag4"} | ||||
| 		DB.Model(&blog2).Association("LocaleTags").Append(tag4) | ||||
| 
 | ||||
| 		DB.Model(&blog).Related(&tags, "LocaleTags") | ||||
| 		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("Should find 3 tags with Related for EN Blog") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Related(&tags, "LocaleTags") | ||||
| 		if !compareTags(tags, []string{"tag4"}) { | ||||
| 			t.Errorf("Should find 1 tags with Related for EN Blog") | ||||
| 		} | ||||
| 
 | ||||
| 		// Replace
 | ||||
| 		var tag5 = &Tag{Locale: "ZH", Value: "tag5"} | ||||
| 		var tag6 = &Tag{Locale: "ZH", Value: "tag6"} | ||||
| 		DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) | ||||
| 
 | ||||
| 		var tags2 []Tag | ||||
| 		DB.Model(&blog).Related(&tags2, "LocaleTags") | ||||
| 		if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		var blog11 Blog | ||||
| 		DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) | ||||
| 		if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { | ||||
| 			t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Related(&tags2, "LocaleTags") | ||||
| 		if !compareTags(tags2, []string{"tag5", "tag6"}) { | ||||
| 			t.Errorf("Should find 2 tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		var blog21 Blog | ||||
| 		DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) | ||||
| 		if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { | ||||
| 			t.Errorf("EN Blog's tags should be changed after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | ||||
| 			t.Errorf("ZH Blog should has three tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { | ||||
| 			t.Errorf("EN Blog should has two tags after Replace") | ||||
| 		} | ||||
| 
 | ||||
| 		// Delete
 | ||||
| 		DB.Model(&blog).Association("LocaleTags").Delete(tag5) | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | ||||
| 			t.Errorf("ZH Blog should has three tags after Delete with EN's tag") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { | ||||
| 			t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog2).Association("LocaleTags").Delete(tag5) | ||||
| 
 | ||||
| 		if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | ||||
| 			t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { | ||||
| 			t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") | ||||
| 		} | ||||
| 
 | ||||
| 		// Clear
 | ||||
| 		DB.Model(&blog2).Association("LocaleTags").Clear() | ||||
| 		if DB.Model(&blog).Association("LocaleTags").Count() != 3 { | ||||
| 			t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { | ||||
| 			t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") | ||||
| 		} | ||||
| 
 | ||||
| 		DB.Model(&blog).Association("LocaleTags").Clear() | ||||
| 		if DB.Model(&blog).Association("LocaleTags").Count() != 0 { | ||||
| 			t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") | ||||
| 		} | ||||
| 
 | ||||
| 		if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { | ||||
| 			t.Errorf("EN Blog's tags should be cleared") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										124
									
								
								naming.go
									
									
									
									
									
								
							
							
						
						
									
										124
									
								
								naming.go
									
									
									
									
									
								
							| @ -1,124 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // Namer is a function type which is given a string and return a string
 | ||||
| type Namer func(string) string | ||||
| 
 | ||||
| // NamingStrategy represents naming strategies
 | ||||
| type NamingStrategy struct { | ||||
| 	DB     Namer | ||||
| 	Table  Namer | ||||
| 	Column Namer | ||||
| } | ||||
| 
 | ||||
| // TheNamingStrategy is being initialized with defaultNamingStrategy
 | ||||
| var TheNamingStrategy = &NamingStrategy{ | ||||
| 	DB:     defaultNamer, | ||||
| 	Table:  defaultNamer, | ||||
| 	Column: defaultNamer, | ||||
| } | ||||
| 
 | ||||
| // AddNamingStrategy sets the naming strategy
 | ||||
| func AddNamingStrategy(ns *NamingStrategy) { | ||||
| 	if ns.DB == nil { | ||||
| 		ns.DB = defaultNamer | ||||
| 	} | ||||
| 	if ns.Table == nil { | ||||
| 		ns.Table = defaultNamer | ||||
| 	} | ||||
| 	if ns.Column == nil { | ||||
| 		ns.Column = defaultNamer | ||||
| 	} | ||||
| 	TheNamingStrategy = ns | ||||
| } | ||||
| 
 | ||||
| // DBName alters the given name by DB
 | ||||
| func (ns *NamingStrategy) DBName(name string) string { | ||||
| 	return ns.DB(name) | ||||
| } | ||||
| 
 | ||||
| // TableName alters the given name by Table
 | ||||
| func (ns *NamingStrategy) TableName(name string) string { | ||||
| 	return ns.Table(name) | ||||
| } | ||||
| 
 | ||||
| // ColumnName alters the given name by Column
 | ||||
| func (ns *NamingStrategy) ColumnName(name string) string { | ||||
| 	return ns.Column(name) | ||||
| } | ||||
| 
 | ||||
| // ToDBName convert string to db name
 | ||||
| func ToDBName(name string) string { | ||||
| 	return TheNamingStrategy.DBName(name) | ||||
| } | ||||
| 
 | ||||
| // ToTableName convert string to table name
 | ||||
| func ToTableName(name string) string { | ||||
| 	return TheNamingStrategy.TableName(name) | ||||
| } | ||||
| 
 | ||||
| // ToColumnName convert string to db name
 | ||||
| func ToColumnName(name string) string { | ||||
| 	return TheNamingStrategy.ColumnName(name) | ||||
| } | ||||
| 
 | ||||
| var smap = newSafeMap() | ||||
| 
 | ||||
| func defaultNamer(name string) string { | ||||
| 	const ( | ||||
| 		lower = false | ||||
| 		upper = true | ||||
| 	) | ||||
| 
 | ||||
| 	if v := smap.Get(name); v != "" { | ||||
| 		return v | ||||
| 	} | ||||
| 
 | ||||
| 	if name == "" { | ||||
| 		return "" | ||||
| 	} | ||||
| 
 | ||||
| 	var ( | ||||
| 		value                                    = commonInitialismsReplacer.Replace(name) | ||||
| 		buf                                      = bytes.NewBufferString("") | ||||
| 		lastCase, currCase, nextCase, nextNumber bool | ||||
| 	) | ||||
| 
 | ||||
| 	for i, v := range value[:len(value)-1] { | ||||
| 		nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') | ||||
| 		nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') | ||||
| 
 | ||||
| 		if i > 0 { | ||||
| 			if currCase == upper { | ||||
| 				if lastCase == upper && (nextCase == upper || nextNumber == upper) { | ||||
| 					buf.WriteRune(v) | ||||
| 				} else { | ||||
| 					if value[i-1] != '_' && value[i+1] != '_' { | ||||
| 						buf.WriteRune('_') | ||||
| 					} | ||||
| 					buf.WriteRune(v) | ||||
| 				} | ||||
| 			} else { | ||||
| 				buf.WriteRune(v) | ||||
| 				if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { | ||||
| 					buf.WriteRune('_') | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			currCase = upper | ||||
| 			buf.WriteRune(v) | ||||
| 		} | ||||
| 		lastCase = currCase | ||||
| 		currCase = nextCase | ||||
| 	} | ||||
| 
 | ||||
| 	buf.WriteByte(value[len(value)-1]) | ||||
| 
 | ||||
| 	s := strings.ToLower(buf.String()) | ||||
| 	smap.Set(name, s) | ||||
| 	return s | ||||
| } | ||||
| @ -1,69 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func TestTheNamingStrategy(t *testing.T) { | ||||
| 
 | ||||
| 	cases := []struct { | ||||
| 		name     string | ||||
| 		namer    gorm.Namer | ||||
| 		expected string | ||||
| 	}{ | ||||
| 		{name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, | ||||
| 		{name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, | ||||
| 		{name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, c := range cases { | ||||
| 		t.Run(c.name, func(t *testing.T) { | ||||
| 			result := c.namer(c.name) | ||||
| 			if result != c.expected { | ||||
| 				t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestNamingStrategy(t *testing.T) { | ||||
| 
 | ||||
| 	dbNameNS := func(name string) string { | ||||
| 		return "db_" + name | ||||
| 	} | ||||
| 	tableNameNS := func(name string) string { | ||||
| 		return "tbl_" + name | ||||
| 	} | ||||
| 	columnNameNS := func(name string) string { | ||||
| 		return "col_" + name | ||||
| 	} | ||||
| 
 | ||||
| 	ns := &gorm.NamingStrategy{ | ||||
| 		DB:     dbNameNS, | ||||
| 		Table:  tableNameNS, | ||||
| 		Column: columnNameNS, | ||||
| 	} | ||||
| 
 | ||||
| 	cases := []struct { | ||||
| 		name     string | ||||
| 		namer    gorm.Namer | ||||
| 		expected string | ||||
| 	}{ | ||||
| 		{name: "auth", expected: "db_auth", namer: ns.DB}, | ||||
| 		{name: "user", expected: "tbl_user", namer: ns.Table}, | ||||
| 		{name: "password", expected: "col_password", namer: ns.Column}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, c := range cases { | ||||
| 		t.Run(c.name, func(t *testing.T) { | ||||
| 			result := c.namer(c.name) | ||||
| 			if result != c.expected { | ||||
| 				t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
| @ -1,84 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import "testing" | ||||
| 
 | ||||
| type PointerStruct struct { | ||||
| 	ID   int64 | ||||
| 	Name *string | ||||
| 	Num  *int | ||||
| } | ||||
| 
 | ||||
| type NormalStruct struct { | ||||
| 	ID   int64 | ||||
| 	Name string | ||||
| 	Num  int | ||||
| } | ||||
| 
 | ||||
| func TestPointerFields(t *testing.T) { | ||||
| 	DB.DropTable(&PointerStruct{}) | ||||
| 	DB.AutoMigrate(&PointerStruct{}) | ||||
| 	var name = "pointer struct 1" | ||||
| 	var num = 100 | ||||
| 	pointerStruct := PointerStruct{Name: &name, Num: &num} | ||||
| 	if DB.Create(&pointerStruct).Error != nil { | ||||
| 		t.Errorf("Failed to save pointer struct") | ||||
| 	} | ||||
| 
 | ||||
| 	var pointerStructResult PointerStruct | ||||
| 	if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { | ||||
| 		t.Errorf("Failed to query saved pointer struct") | ||||
| 	} | ||||
| 
 | ||||
| 	var tableName = DB.NewScope(&PointerStruct{}).TableName() | ||||
| 
 | ||||
| 	var normalStruct NormalStruct | ||||
| 	DB.Table(tableName).First(&normalStruct) | ||||
| 	if normalStruct.Name != name || normalStruct.Num != num { | ||||
| 		t.Errorf("Failed to query saved Normal struct") | ||||
| 	} | ||||
| 
 | ||||
| 	var nilPointerStruct = PointerStruct{} | ||||
| 	if err := DB.Create(&nilPointerStruct).Error; err != nil { | ||||
| 		t.Error("Failed to save nil pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var pointerStruct2 PointerStruct | ||||
| 	if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { | ||||
| 		t.Error("Failed to query saved nil pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var normalStruct2 NormalStruct | ||||
| 	if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { | ||||
| 		t.Error("Failed to query saved nil pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var partialNilPointerStruct1 = PointerStruct{Num: &num} | ||||
| 	if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { | ||||
| 		t.Error("Failed to save partial nil pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var pointerStruct3 PointerStruct | ||||
| 	if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { | ||||
| 		t.Error("Failed to query saved partial nil pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var normalStruct3 NormalStruct | ||||
| 	if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { | ||||
| 		t.Error("Failed to query saved partial pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var partialNilPointerStruct2 = PointerStruct{Name: &name} | ||||
| 	if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { | ||||
| 		t.Error("Failed to save partial nil pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var pointerStruct4 PointerStruct | ||||
| 	if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { | ||||
| 		t.Error("Failed to query saved partial nil pointer struct", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var normalStruct4 NormalStruct | ||||
| 	if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { | ||||
| 		t.Error("Failed to query saved partial pointer struct", err) | ||||
| 	} | ||||
| } | ||||
| @ -1,366 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| type Cat struct { | ||||
| 	Id   int | ||||
| 	Name string | ||||
| 	Toy  Toy `gorm:"polymorphic:Owner;"` | ||||
| } | ||||
| 
 | ||||
| type Dog struct { | ||||
| 	Id   int | ||||
| 	Name string | ||||
| 	Toys []Toy `gorm:"polymorphic:Owner;"` | ||||
| } | ||||
| 
 | ||||
| type Hamster struct { | ||||
| 	Id           int | ||||
| 	Name         string | ||||
| 	PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` | ||||
| 	OtherToy     Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` | ||||
| } | ||||
| 
 | ||||
| type Toy struct { | ||||
| 	Id        int | ||||
| 	Name      string | ||||
| 	OwnerId   int | ||||
| 	OwnerType string | ||||
| } | ||||
| 
 | ||||
| var compareToys = func(toys []Toy, contents []string) bool { | ||||
| 	var toyContents []string | ||||
| 	for _, toy := range toys { | ||||
| 		toyContents = append(toyContents, toy.Name) | ||||
| 	} | ||||
| 	sort.Strings(toyContents) | ||||
| 	sort.Strings(contents) | ||||
| 	return reflect.DeepEqual(toyContents, contents) | ||||
| } | ||||
| 
 | ||||
| func TestPolymorphic(t *testing.T) { | ||||
| 	cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}} | ||||
| 	dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} | ||||
| 	DB.Save(&cat).Save(&dog) | ||||
| 
 | ||||
| 	if DB.Model(&cat).Association("Toy").Count() != 1 { | ||||
| 		t.Errorf("Cat's toys count should be 1") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 2 { | ||||
| 		t.Errorf("Dog's toys count should be 2") | ||||
| 	} | ||||
| 
 | ||||
| 	// Query
 | ||||
| 	var catToys []Toy | ||||
| 	if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { | ||||
| 		t.Errorf("Did not find any has one polymorphic association") | ||||
| 	} else if len(catToys) != 1 { | ||||
| 		t.Errorf("Should have found only one polymorphic has one association") | ||||
| 	} else if catToys[0].Name != cat.Toy.Name { | ||||
| 		t.Errorf("Should have found the proper has one polymorphic association") | ||||
| 	} | ||||
| 
 | ||||
| 	var dogToys []Toy | ||||
| 	if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { | ||||
| 		t.Errorf("Did not find any polymorphic has many associations") | ||||
| 	} else if len(dogToys) != len(dog.Toys) { | ||||
| 		t.Errorf("Should have found all polymorphic has many associations") | ||||
| 	} | ||||
| 
 | ||||
| 	var catToy Toy | ||||
| 	DB.Model(&cat).Association("Toy").Find(&catToy) | ||||
| 	if catToy.Name != cat.Toy.Name { | ||||
| 		t.Errorf("Should find has one polymorphic association") | ||||
| 	} | ||||
| 
 | ||||
| 	var dogToys1 []Toy | ||||
| 	DB.Model(&dog).Association("Toys").Find(&dogToys1) | ||||
| 	if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { | ||||
| 		t.Errorf("Should find has many polymorphic association") | ||||
| 	} | ||||
| 
 | ||||
| 	// Append
 | ||||
| 	DB.Model(&cat).Association("Toy").Append(&Toy{ | ||||
| 		Name: "cat toy 2", | ||||
| 	}) | ||||
| 
 | ||||
| 	var catToy2 Toy | ||||
| 	DB.Model(&cat).Association("Toy").Find(&catToy2) | ||||
| 	if catToy2.Name != "cat toy 2" { | ||||
| 		t.Errorf("Should update has one polymorphic association with Append") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&cat).Association("Toy").Count() != 1 { | ||||
| 		t.Errorf("Cat's toys count should be 1 after Append") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 2 { | ||||
| 		t.Errorf("Should return two polymorphic has many associations") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&dog).Association("Toys").Append(&Toy{ | ||||
| 		Name: "dog toy 3", | ||||
| 	}) | ||||
| 
 | ||||
| 	var dogToys2 []Toy | ||||
| 	DB.Model(&dog).Association("Toys").Find(&dogToys2) | ||||
| 	if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { | ||||
| 		t.Errorf("Dog's toys should be updated with Append") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 3 { | ||||
| 		t.Errorf("Should return three polymorphic has many associations") | ||||
| 	} | ||||
| 
 | ||||
| 	// Replace
 | ||||
| 	DB.Model(&cat).Association("Toy").Replace(&Toy{ | ||||
| 		Name: "cat toy 3", | ||||
| 	}) | ||||
| 
 | ||||
| 	var catToy3 Toy | ||||
| 	DB.Model(&cat).Association("Toy").Find(&catToy3) | ||||
| 	if catToy3.Name != "cat toy 3" { | ||||
| 		t.Errorf("Should update has one polymorphic association with Replace") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&cat).Association("Toy").Count() != 1 { | ||||
| 		t.Errorf("Cat's toys count should be 1 after Replace") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 3 { | ||||
| 		t.Errorf("Should return three polymorphic has many associations") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&dog).Association("Toys").Replace(&Toy{ | ||||
| 		Name: "dog toy 4", | ||||
| 	}, []Toy{ | ||||
| 		{Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, | ||||
| 	}) | ||||
| 
 | ||||
| 	var dogToys3 []Toy | ||||
| 	DB.Model(&dog).Association("Toys").Find(&dogToys3) | ||||
| 	if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { | ||||
| 		t.Errorf("Dog's toys should be updated with Replace") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 4 { | ||||
| 		t.Errorf("Should return three polymorphic has many associations") | ||||
| 	} | ||||
| 
 | ||||
| 	// Delete
 | ||||
| 	DB.Model(&cat).Association("Toy").Delete(&catToy2) | ||||
| 
 | ||||
| 	var catToy4 Toy | ||||
| 	DB.Model(&cat).Association("Toy").Find(&catToy4) | ||||
| 	if catToy4.Name != "cat toy 3" { | ||||
| 		t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&cat).Association("Toy").Count() != 1 { | ||||
| 		t.Errorf("Cat's toys count should be 1") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 4 { | ||||
| 		t.Errorf("Dog's toys count should be 4") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&cat).Association("Toy").Delete(&catToy3) | ||||
| 
 | ||||
| 	if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { | ||||
| 		t.Errorf("Toy should be deleted with Delete") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&cat).Association("Toy").Count() != 0 { | ||||
| 		t.Errorf("Cat's toys count should be 0 after Delete") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 4 { | ||||
| 		t.Errorf("Dog's toys count should not be changed when delete cat's toy") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&dog).Association("Toys").Delete(&dogToys2) | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 4 { | ||||
| 		t.Errorf("Dog's toys count should not be changed when delete unrelated toys") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&dog).Association("Toys").Delete(&dogToys3) | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 0 { | ||||
| 		t.Errorf("Dog's toys count should be deleted with Delete") | ||||
| 	} | ||||
| 
 | ||||
| 	// Clear
 | ||||
| 	DB.Model(&cat).Association("Toy").Append(&Toy{ | ||||
| 		Name: "cat toy 2", | ||||
| 	}) | ||||
| 
 | ||||
| 	if DB.Model(&cat).Association("Toy").Count() != 1 { | ||||
| 		t.Errorf("Cat's toys should be added with Append") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&cat).Association("Toy").Clear() | ||||
| 
 | ||||
| 	if DB.Model(&cat).Association("Toy").Count() != 0 { | ||||
| 		t.Errorf("Cat's toys should be cleared with Clear") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&dog).Association("Toys").Append(&Toy{ | ||||
| 		Name: "dog toy 8", | ||||
| 	}) | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 1 { | ||||
| 		t.Errorf("Dog's toys should be added with Append") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&dog).Association("Toys").Clear() | ||||
| 
 | ||||
| 	if DB.Model(&dog).Association("Toys").Count() != 0 { | ||||
| 		t.Errorf("Dog's toys should be cleared with Clear") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNamedPolymorphic(t *testing.T) { | ||||
| 	hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} | ||||
| 	DB.Save(&hamster) | ||||
| 
 | ||||
| 	hamster2 := Hamster{} | ||||
| 	DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) | ||||
| 	if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { | ||||
| 		t.Errorf("Hamster's preferred toy couldn't be preloaded") | ||||
| 	} | ||||
| 	if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name { | ||||
| 		t.Errorf("Hamster's other toy couldn't be preloaded") | ||||
| 	} | ||||
| 
 | ||||
| 	// clear to omit Toy.Id in count
 | ||||
| 	hamster2.PreferredToy = Toy{} | ||||
| 	hamster2.OtherToy = Toy{} | ||||
| 
 | ||||
| 	if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { | ||||
| 		t.Errorf("Hamster's preferred toy count should be 1") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | ||||
| 		t.Errorf("Hamster's other toy count should be 1") | ||||
| 	} | ||||
| 
 | ||||
| 	// Query
 | ||||
| 	var hamsterToys []Toy | ||||
| 	if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() { | ||||
| 		t.Errorf("Did not find any has one polymorphic association") | ||||
| 	} else if len(hamsterToys) != 1 { | ||||
| 		t.Errorf("Should have found only one polymorphic has one association") | ||||
| 	} else if hamsterToys[0].Name != hamster.PreferredToy.Name { | ||||
| 		t.Errorf("Should have found the proper has one polymorphic association") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() { | ||||
| 		t.Errorf("Did not find any has one polymorphic association") | ||||
| 	} else if len(hamsterToys) != 1 { | ||||
| 		t.Errorf("Should have found only one polymorphic has one association") | ||||
| 	} else if hamsterToys[0].Name != hamster.OtherToy.Name { | ||||
| 		t.Errorf("Should have found the proper has one polymorphic association") | ||||
| 	} | ||||
| 
 | ||||
| 	hamsterToy := Toy{} | ||||
| 	DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) | ||||
| 	if hamsterToy.Name != hamster.PreferredToy.Name { | ||||
| 		t.Errorf("Should find has one polymorphic association") | ||||
| 	} | ||||
| 	hamsterToy = Toy{} | ||||
| 	DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) | ||||
| 	if hamsterToy.Name != hamster.OtherToy.Name { | ||||
| 		t.Errorf("Should find has one polymorphic association") | ||||
| 	} | ||||
| 
 | ||||
| 	// Append
 | ||||
| 	DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ | ||||
| 		Name: "bike 2", | ||||
| 	}) | ||||
| 	DB.Model(&hamster).Association("OtherToy").Append(&Toy{ | ||||
| 		Name: "treadmill 2", | ||||
| 	}) | ||||
| 
 | ||||
| 	hamsterToy = Toy{} | ||||
| 	DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) | ||||
| 	if hamsterToy.Name != "bike 2" { | ||||
| 		t.Errorf("Should update has one polymorphic association with Append") | ||||
| 	} | ||||
| 
 | ||||
| 	hamsterToy = Toy{} | ||||
| 	DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) | ||||
| 	if hamsterToy.Name != "treadmill 2" { | ||||
| 		t.Errorf("Should update has one polymorphic association with Append") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { | ||||
| 		t.Errorf("Hamster's toys count should be 1 after Append") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | ||||
| 		t.Errorf("Hamster's toys count should be 1 after Append") | ||||
| 	} | ||||
| 
 | ||||
| 	// Replace
 | ||||
| 	DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ | ||||
| 		Name: "bike 3", | ||||
| 	}) | ||||
| 	DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ | ||||
| 		Name: "treadmill 3", | ||||
| 	}) | ||||
| 
 | ||||
| 	hamsterToy = Toy{} | ||||
| 	DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) | ||||
| 	if hamsterToy.Name != "bike 3" { | ||||
| 		t.Errorf("Should update has one polymorphic association with Replace") | ||||
| 	} | ||||
| 
 | ||||
| 	hamsterToy = Toy{} | ||||
| 	DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) | ||||
| 	if hamsterToy.Name != "treadmill 3" { | ||||
| 		t.Errorf("Should update has one polymorphic association with Replace") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { | ||||
| 		t.Errorf("hamster's toys count should be 1 after Replace") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | ||||
| 		t.Errorf("hamster's toys count should be 1 after Replace") | ||||
| 	} | ||||
| 
 | ||||
| 	// Clear
 | ||||
| 	DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ | ||||
| 		Name: "bike 2", | ||||
| 	}) | ||||
| 	DB.Model(&hamster).Association("OtherToy").Append(&Toy{ | ||||
| 		Name: "treadmill 2", | ||||
| 	}) | ||||
| 
 | ||||
| 	if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { | ||||
| 		t.Errorf("Hamster's toys should be added with Append") | ||||
| 	} | ||||
| 	if DB.Model(&hamster).Association("OtherToy").Count() != 1 { | ||||
| 		t.Errorf("Hamster's toys should be added with Append") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&hamster).Association("PreferredToy").Clear() | ||||
| 
 | ||||
| 	if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { | ||||
| 		t.Errorf("Hamster's preferred toy should be cleared with Clear") | ||||
| 	} | ||||
| 	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { | ||||
| 		t.Errorf("Hamster's other toy should be still available") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&hamster).Association("OtherToy").Clear() | ||||
| 	if DB.Model(&hamster).Association("OtherToy").Count() != 0 { | ||||
| 		t.Errorf("Hamster's other toy should be cleared with Clear") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										1701
									
								
								preload_test.go
									
									
									
									
									
								
							
							
						
						
									
										1701
									
								
								preload_test.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										841
									
								
								query_test.go
									
									
									
									
									
								
							
							
						
						
									
										841
									
								
								query_test.go
									
									
									
									
									
								
							| @ -1,841 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 
 | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestFirstAndLast(t *testing.T) { | ||||
| 	DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) | ||||
| 	DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) | ||||
| 
 | ||||
| 	var user1, user2, user3, user4 User | ||||
| 	DB.First(&user1) | ||||
| 	DB.Order("id").Limit(1).Find(&user2) | ||||
| 
 | ||||
| 	ptrOfUser3 := &user3 | ||||
| 	DB.Last(&ptrOfUser3) | ||||
| 	DB.Order("id desc").Limit(1).Find(&user4) | ||||
| 	if user1.Id != user2.Id || user3.Id != user4.Id { | ||||
| 		t.Errorf("First and Last should by order by primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	var users []User | ||||
| 	DB.First(&users) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Find first record as slice") | ||||
| 	} | ||||
| 
 | ||||
| 	var user User | ||||
| 	if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { | ||||
| 		t.Errorf("Should not raise any error when order with Join table") | ||||
| 	} | ||||
| 
 | ||||
| 	if user.Email != "" { | ||||
| 		t.Errorf("User's Email should be blank as no one set it") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { | ||||
| 	DB.Save(&Animal{Name: "animal1"}) | ||||
| 	DB.Save(&Animal{Name: "animal2"}) | ||||
| 
 | ||||
| 	var animal1, animal2, animal3, animal4 Animal | ||||
| 	DB.First(&animal1) | ||||
| 	DB.Order("counter").Limit(1).Find(&animal2) | ||||
| 
 | ||||
| 	DB.Last(&animal3) | ||||
| 	DB.Order("counter desc").Limit(1).Find(&animal4) | ||||
| 	if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { | ||||
| 		t.Errorf("First and Last should work correctly") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFirstAndLastWithRaw(t *testing.T) { | ||||
| 	user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}} | ||||
| 	user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}} | ||||
| 	DB.Save(&user1) | ||||
| 	DB.Save(&user2) | ||||
| 
 | ||||
| 	var user3, user4 User | ||||
| 	DB.Raw("select * from users WHERE name = ?", "user").First(&user3) | ||||
| 	if user3.Id != user1.Id { | ||||
| 		t.Errorf("Find first record with raw") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Raw("select * from users WHERE name = ?", "user").Last(&user4) | ||||
| 	if user4.Id != user2.Id { | ||||
| 		t.Errorf("Find last record with raw") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUIntPrimaryKey(t *testing.T) { | ||||
| 	var animal Animal | ||||
| 	DB.First(&animal, uint64(1)) | ||||
| 	if animal.Counter != 1 { | ||||
| 		t.Errorf("Fetch a record from with a non-int primary key should work, but failed") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) | ||||
| 	if animal.Counter != 2 { | ||||
| 		t.Errorf("Fetch a record from with a non-int primary key should work, but failed") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCustomizedTypePrimaryKey(t *testing.T) { | ||||
| 	type ID uint | ||||
| 	type CustomizedTypePrimaryKey struct { | ||||
| 		ID   ID | ||||
| 		Name string | ||||
| 	} | ||||
| 
 | ||||
| 	DB.AutoMigrate(&CustomizedTypePrimaryKey{}) | ||||
| 
 | ||||
| 	p1 := CustomizedTypePrimaryKey{Name: "p1"} | ||||
| 	p2 := CustomizedTypePrimaryKey{Name: "p2"} | ||||
| 	p3 := CustomizedTypePrimaryKey{Name: "p3"} | ||||
| 	DB.Create(&p1) | ||||
| 	DB.Create(&p2) | ||||
| 	DB.Create(&p3) | ||||
| 
 | ||||
| 	var p CustomizedTypePrimaryKey | ||||
| 
 | ||||
| 	if err := DB.First(&p, p2.ID).Error; err == nil { | ||||
| 		t.Errorf("Should return error for invalid query condition") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if p.Name != "p2" { | ||||
| 		t.Errorf("Should find correct value when querying with customized type for primary key") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { | ||||
| 	type AddressByZipCode struct { | ||||
| 		ZipCode string `gorm:"primary_key"` | ||||
| 		Address string | ||||
| 	} | ||||
| 
 | ||||
| 	DB.AutoMigrate(&AddressByZipCode{}) | ||||
| 	DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) | ||||
| 
 | ||||
| 	var address AddressByZipCode | ||||
| 	DB.First(&address, "00501") | ||||
| 	if address.ZipCode != "00501" { | ||||
| 		t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFindAsSliceOfPointers(t *testing.T) { | ||||
| 	DB.Save(&User{Name: "user"}) | ||||
| 
 | ||||
| 	var users []User | ||||
| 	DB.Find(&users) | ||||
| 
 | ||||
| 	var userPointers []*User | ||||
| 	DB.Find(&userPointers) | ||||
| 
 | ||||
| 	if len(users) == 0 || len(users) != len(userPointers) { | ||||
| 		t.Errorf("Find slice of pointers") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSearchWithPlainSQL(t *testing.T) { | ||||
| 	user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} | ||||
| 	user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} | ||||
| 	user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 	scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") | ||||
| 
 | ||||
| 	if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { | ||||
| 		t.Errorf("Search with plain SQL") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { | ||||
| 		t.Errorf("Search with plan SQL (regexp)") | ||||
| 	} | ||||
| 
 | ||||
| 	var users []User | ||||
| 	DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1) | ||||
| 	if len(users) != 2 { | ||||
| 		t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) | ||||
| 	if len(users) != 3 { | ||||
| 		t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	scopedb.Where("age <> ?", 20).Find(&users) | ||||
| 	if len(users) != 2 { | ||||
| 		t.Errorf("Should found 2 users age != 20, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) | ||||
| 	if len(users) != 2 { | ||||
| 		t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	scopedb.Where("birthday > ?", "2002-10-10").Find(&users) | ||||
| 	if len(users) != 2 { | ||||
| 		t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) | ||||
| 	if len(users) != 2 { | ||||
| 		t.Errorf("Should found 2 users, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) | ||||
| 	if len(users) != 3 { | ||||
| 		t.Errorf("Should found 3 users, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where("id in (?)", user1.Id).Find(&users) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Should found 1 users, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { | ||||
| 		t.Error("no error should happen when query with empty slice, but got: ", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { | ||||
| 		t.Error("no error should happen when query with empty slice, but got: ", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { | ||||
| 		t.Errorf("Should not get RecordNotFound error when looking for none existing records") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSearchWithTwoDimensionalArray(t *testing.T) { | ||||
| 	var users []User | ||||
| 	user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | ||||
| 	user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | ||||
| 	user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | ||||
| 	DB.Create(&user1) | ||||
| 	DB.Create(&user2) | ||||
| 	DB.Create(&user3) | ||||
| 
 | ||||
| 	if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" { | ||||
| 		if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { | ||||
| 			t.Errorf("No error should happen when query with 2D array, but got %v", err) | ||||
| 
 | ||||
| 			if len(users) != 2 { | ||||
| 				t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if dialect := DB.Dialect().GetName(); dialect == "mssql" { | ||||
| 		if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { | ||||
| 			t.Errorf("No error should happen when query with 2D array, but got %v", err) | ||||
| 
 | ||||
| 			if len(users) != 2 { | ||||
| 				t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSearchWithStruct(t *testing.T) { | ||||
| 	user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | ||||
| 	user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | ||||
| 	user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 
 | ||||
| 	if DB.Where(user1.Id).First(&User{}).RecordNotFound() { | ||||
| 		t.Errorf("Search with primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.First(&User{}, user1.Id).RecordNotFound() { | ||||
| 		t.Errorf("Search with primary key as inline condition") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { | ||||
| 		t.Errorf("Search with primary key as inline condition") | ||||
| 	} | ||||
| 
 | ||||
| 	var users []User | ||||
| 	DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) | ||||
| 	if len(users) != 3 { | ||||
| 		t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	var user User | ||||
| 	DB.First(&user, &User{Name: user1.Name}) | ||||
| 	if user.Id == 0 || user.Name != user1.Name { | ||||
| 		t.Errorf("Search first record with inline pointer of struct") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.First(&user, User{Name: user1.Name}) | ||||
| 	if user.Id == 0 || user.Name != user1.Name { | ||||
| 		t.Errorf("Search first record with inline struct") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: user1.Name}).First(&user) | ||||
| 	if user.Id == 0 || user.Name != user1.Name { | ||||
| 		t.Errorf("Search first record with where struct") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Find(&users, &User{Name: user2.Name}) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Search all records with inline struct") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSearchWithMap(t *testing.T) { | ||||
| 	companyID := 1 | ||||
| 	user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | ||||
| 	user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | ||||
| 	user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | ||||
| 	user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) | ||||
| 
 | ||||
| 	var user User | ||||
| 	DB.First(&user, map[string]interface{}{"name": user1.Name}) | ||||
| 	if user.Id == 0 || user.Name != user1.Name { | ||||
| 		t.Errorf("Search first record with inline map") | ||||
| 	} | ||||
| 
 | ||||
| 	user = User{} | ||||
| 	DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) | ||||
| 	if user.Id == 0 || user.Name != user2.Name { | ||||
| 		t.Errorf("Search first record with where map") | ||||
| 	} | ||||
| 
 | ||||
| 	var users []User | ||||
| 	DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Search all records with inline map") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Find(&users, map[string]interface{}{"name": user3.Name}) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Search all records with inline map") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) | ||||
| 	if len(users) != 0 { | ||||
| 		t.Errorf("Search all records with inline map containing null value finding 0 records") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Search all records with inline map containing null value finding 1 record") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) | ||||
| 	if len(users) != 1 { | ||||
| 		t.Errorf("Search all records with inline multiple value map") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSearchWithEmptyChain(t *testing.T) { | ||||
| 	user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} | ||||
| 	user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} | ||||
| 	user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 
 | ||||
| 	if DB.Where("").Where("").First(&User{}).Error != nil { | ||||
| 		t.Errorf("Should not raise any error if searching with empty strings") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { | ||||
| 		t.Errorf("Should not raise any error if searching with empty struct") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { | ||||
| 		t.Errorf("Should not raise any error if searching with empty map") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelect(t *testing.T) { | ||||
| 	user1 := User{Name: "SelectUser1"} | ||||
| 	DB.Save(&user1) | ||||
| 
 | ||||
| 	var user User | ||||
| 	DB.Where("name = ?", user1.Name).Select("name").Find(&user) | ||||
| 	if user.Id != 0 { | ||||
| 		t.Errorf("Should not have ID because only selected name, %+v", user.Id) | ||||
| 	} | ||||
| 
 | ||||
| 	if user.Name != user1.Name { | ||||
| 		t.Errorf("Should have user Name when selected it") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOrderAndPluck(t *testing.T) { | ||||
| 	user1 := User{Name: "OrderPluckUser1", Age: 1} | ||||
| 	user2 := User{Name: "OrderPluckUser2", Age: 10} | ||||
| 	user3 := User{Name: "OrderPluckUser3", Age: 20} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 	scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") | ||||
| 
 | ||||
| 	var user User | ||||
| 	scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user) | ||||
| 	if user.Name != "OrderPluckUser2" { | ||||
| 		t.Errorf("Order with sql expression") | ||||
| 	} | ||||
| 
 | ||||
| 	var ages []int64 | ||||
| 	scopedb.Order("age desc").Pluck("age", &ages) | ||||
| 	if ages[0] != 20 { | ||||
| 		t.Errorf("The first age should be 20 when order with age desc") | ||||
| 	} | ||||
| 
 | ||||
| 	var ages1, ages2 []int64 | ||||
| 	scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) | ||||
| 	if !reflect.DeepEqual(ages1, ages2) { | ||||
| 		t.Errorf("The first order is the primary order") | ||||
| 	} | ||||
| 
 | ||||
| 	var ages3, ages4 []int64 | ||||
| 	scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) | ||||
| 	if reflect.DeepEqual(ages3, ages4) { | ||||
| 		t.Errorf("Reorder should work") | ||||
| 	} | ||||
| 
 | ||||
| 	var names []string | ||||
| 	var ages5 []int64 | ||||
| 	scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) | ||||
| 	if names != nil && ages5 != nil { | ||||
| 		if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { | ||||
| 			t.Errorf("Order with multiple orders") | ||||
| 		} | ||||
| 	} else { | ||||
| 		t.Errorf("Order with multiple orders") | ||||
| 	} | ||||
| 
 | ||||
| 	var ages6 []int64 | ||||
| 	if err := scopedb.Order("").Pluck("age", &ages6).Error; err != nil { | ||||
| 		t.Errorf("An empty string as order clause produces invalid queries") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(User{}).Select("name, age").Find(&[]User{}) | ||||
| } | ||||
| 
 | ||||
| func TestLimit(t *testing.T) { | ||||
| 	user1 := User{Name: "LimitUser1", Age: 1} | ||||
| 	user2 := User{Name: "LimitUser2", Age: 10} | ||||
| 	user3 := User{Name: "LimitUser3", Age: 20} | ||||
| 	user4 := User{Name: "LimitUser4", Age: 10} | ||||
| 	user5 := User{Name: "LimitUser5", Age: 20} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) | ||||
| 
 | ||||
| 	var users1, users2, users3 []User | ||||
| 	DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) | ||||
| 
 | ||||
| 	if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { | ||||
| 		t.Errorf("Limit should works") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOffset(t *testing.T) { | ||||
| 	for i := 0; i < 20; i++ { | ||||
| 		DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) | ||||
| 	} | ||||
| 	var users1, users2, users3, users4 []User | ||||
| 	DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) | ||||
| 
 | ||||
| 	if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { | ||||
| 		t.Errorf("Offset should work") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestLimitAndOffsetSQL(t *testing.T) { | ||||
| 	user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10} | ||||
| 	user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20} | ||||
| 	user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30} | ||||
| 	user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40} | ||||
| 	user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50} | ||||
| 	if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		name          string | ||||
| 		limit, offset interface{} | ||||
| 		users         []*User | ||||
| 		ok            bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:   "OK", | ||||
| 			limit:  float64(2), | ||||
| 			offset: float64(2), | ||||
| 			users: []*User{ | ||||
| 				&User{Name: "TestLimitAndOffsetSQL3", Age: 30}, | ||||
| 				&User{Name: "TestLimitAndOffsetSQL2", Age: 20}, | ||||
| 			}, | ||||
| 			ok: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:   "Limit parse error", | ||||
| 			limit:  float64(1000000), // 1e+06
 | ||||
| 			offset: float64(2), | ||||
| 			ok:     false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:   "Offset parse error", | ||||
| 			limit:  float64(2), | ||||
| 			offset: float64(1000000), // 1e+06
 | ||||
| 			ok:     false, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			var users []*User | ||||
| 			err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error | ||||
| 			if tt.ok { | ||||
| 				if err != nil { | ||||
| 					t.Errorf("error expected nil, but got %v", err) | ||||
| 				} | ||||
| 				if len(users) != len(tt.users) { | ||||
| 					t.Errorf("users length expected %d, but got %d", len(tt.users), len(users)) | ||||
| 				} | ||||
| 				for i := range tt.users { | ||||
| 					if users[i].Name != tt.users[i].Name { | ||||
| 						t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name) | ||||
| 					} | ||||
| 					if users[i].Age != tt.users[i].Age { | ||||
| 						t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age) | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				if err == nil { | ||||
| 					t.Error("error expected not nil, but got nil") | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOr(t *testing.T) { | ||||
| 	user1 := User{Name: "OrUser1", Age: 1} | ||||
| 	user2 := User{Name: "OrUser2", Age: 10} | ||||
| 	user3 := User{Name: "OrUser3", Age: 20} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 
 | ||||
| 	var users []User | ||||
| 	DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) | ||||
| 	if len(users) != 2 { | ||||
| 		t.Errorf("Find users with or") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCount(t *testing.T) { | ||||
| 	user1 := User{Name: "CountUser1", Age: 1} | ||||
| 	user2 := User{Name: "CountUser2", Age: 10} | ||||
| 	user3 := User{Name: "CountUser3", Age: 20} | ||||
| 
 | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 	var count, count1, count2 int64 | ||||
| 	var users []User | ||||
| 
 | ||||
| 	if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { | ||||
| 		t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	if count != int64(len(users)) { | ||||
| 		t.Errorf("Count() method should get correct value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) | ||||
| 	if count1 != 1 || count2 != 3 { | ||||
| 		t.Errorf("Multiple count in chain") | ||||
| 	} | ||||
| 
 | ||||
| 	var count3 int | ||||
| 	if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { | ||||
| 		t.Errorf("Not error should happen, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count3 != 2 { | ||||
| 		t.Errorf("Should get correct count, but got %v", count3) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNot(t *testing.T) { | ||||
| 	DB.Create(getPreparedUser("user1", "not")) | ||||
| 	DB.Create(getPreparedUser("user2", "not")) | ||||
| 	DB.Create(getPreparedUser("user3", "not")) | ||||
| 
 | ||||
| 	user4 := getPreparedUser("user4", "not") | ||||
| 	user4.Company = Company{} | ||||
| 	DB.Create(user4) | ||||
| 
 | ||||
| 	DB := DB.Where("role = ?", "not") | ||||
| 
 | ||||
| 	var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User | ||||
| 	if DB.Find(&users1).RowsAffected != 4 { | ||||
| 		t.Errorf("should find 4 not users") | ||||
| 	} | ||||
| 	DB.Not(users1[0].Id).Find(&users2) | ||||
| 
 | ||||
| 	if len(users1)-len(users2) != 1 { | ||||
| 		t.Errorf("Should ignore the first users with Not") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Not([]int{}).Find(&users3) | ||||
| 	if len(users1)-len(users3) != 0 { | ||||
| 		t.Errorf("Should find all users with a blank condition") | ||||
| 	} | ||||
| 
 | ||||
| 	var name3Count int64 | ||||
| 	DB.Table("users").Where("name = ?", "user3").Count(&name3Count) | ||||
| 	DB.Not("name", "user3").Find(&users4) | ||||
| 	if len(users1)-len(users4) != int(name3Count) { | ||||
| 		t.Errorf("Should find all users' name not equal 3") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Not("name = ?", "user3").Find(&users4) | ||||
| 	if len(users1)-len(users4) != int(name3Count) { | ||||
| 		t.Errorf("Should find all users' name not equal 3") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Not("name <> ?", "user3").Find(&users4) | ||||
| 	if len(users4) != int(name3Count) { | ||||
| 		t.Errorf("Should find all users' name not equal 3") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Not(User{Name: "user3"}).Find(&users5) | ||||
| 
 | ||||
| 	if len(users1)-len(users5) != int(name3Count) { | ||||
| 		t.Errorf("Should find all users' name not equal 3") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) | ||||
| 	if len(users1)-len(users6) != int(name3Count) { | ||||
| 		t.Errorf("Should find all users' name not equal 3") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) | ||||
| 	if len(users1)-len(users7) != 2 { // not user3 or user4
 | ||||
| 		t.Errorf("Should find all user's name not equal to 3 who do not have company id") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Not("name", []string{"user3"}).Find(&users8) | ||||
| 	if len(users1)-len(users8) != int(name3Count) { | ||||
| 		t.Errorf("Should find all users' name not equal 3") | ||||
| 	} | ||||
| 
 | ||||
| 	var name2Count int64 | ||||
| 	DB.Table("users").Where("name = ?", "user2").Count(&name2Count) | ||||
| 	DB.Not("name", []string{"user3", "user2"}).Find(&users9) | ||||
| 	if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { | ||||
| 		t.Errorf("Should find all users' name not equal 3") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFillSmallerStruct(t *testing.T) { | ||||
| 	user1 := User{Name: "SmallerUser", Age: 100} | ||||
| 	DB.Save(&user1) | ||||
| 	type SimpleUser struct { | ||||
| 		Name      string | ||||
| 		Id        int64 | ||||
| 		UpdatedAt time.Time | ||||
| 		CreatedAt time.Time | ||||
| 	} | ||||
| 
 | ||||
| 	var simpleUser SimpleUser | ||||
| 	DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) | ||||
| 
 | ||||
| 	if simpleUser.Id == 0 || simpleUser.Name == "" { | ||||
| 		t.Errorf("Should fill data correctly into smaller struct") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFindOrInitialize(t *testing.T) { | ||||
| 	var user1, user2, user3, user4, user5, user6 User | ||||
| 	DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) | ||||
| 	if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { | ||||
| 		t.Errorf("user should be initialized with search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) | ||||
| 	if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { | ||||
| 		t.Errorf("user should be initialized with search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) | ||||
| 	if user3.Name != "find or init 2" || user3.Id != 0 { | ||||
| 		t.Errorf("user should be initialized with inline search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) | ||||
| 	if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { | ||||
| 		t.Errorf("user should be initialized with search value and attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) | ||||
| 	if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { | ||||
| 		t.Errorf("user should be initialized with search value and assign attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Save(&User{Name: "find or init", Age: 33}) | ||||
| 	DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) | ||||
| 	if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { | ||||
| 		t.Errorf("user should be found and not initialized by Attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) | ||||
| 	if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { | ||||
| 		t.Errorf("user should be found with FirstOrInit") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) | ||||
| 	if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { | ||||
| 		t.Errorf("user should be found and updated with assigned attrs") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFindOrCreate(t *testing.T) { | ||||
| 	var user1, user2, user3, user4, user5, user6, user7, user8 User | ||||
| 	DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) | ||||
| 	if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { | ||||
| 		t.Errorf("user should be created with search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) | ||||
| 	if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { | ||||
| 		t.Errorf("user should be created with search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) | ||||
| 	if user3.Name != "find or create 2" || user3.Id == 0 { | ||||
| 		t.Errorf("user should be created with inline search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) | ||||
| 	if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { | ||||
| 		t.Errorf("user should be created with search value and attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	updatedAt1 := user4.UpdatedAt | ||||
| 	DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) | ||||
| 	if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("UpdateAt should be changed when update values with assign") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) | ||||
| 	if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { | ||||
| 		t.Errorf("user should be created with search value and assigned attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) | ||||
| 	if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { | ||||
| 		t.Errorf("user should be found and not initialized by Attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) | ||||
| 	if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { | ||||
| 		t.Errorf("user should be found and updated with assigned attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or create"}).Find(&user7) | ||||
| 	if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { | ||||
| 		t.Errorf("user should be found and updated with assigned attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) | ||||
| 	if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { | ||||
| 		t.Errorf("embedded struct email should be saved") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { | ||||
| 		t.Errorf("embedded struct credit card should be saved") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithEscapedFieldName(t *testing.T) { | ||||
| 	user1 := User{Name: "EscapedFieldNameUser", Age: 1} | ||||
| 	user2 := User{Name: "EscapedFieldNameUser", Age: 10} | ||||
| 	user3 := User{Name: "EscapedFieldNameUser", Age: 20} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 
 | ||||
| 	var names []string | ||||
| 	DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) | ||||
| 
 | ||||
| 	if len(names) != 3 { | ||||
| 		t.Errorf("Expected 3 name, but got: %d", len(names)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithVariables(t *testing.T) { | ||||
| 	DB.Save(&User{Name: "jinzhu"}) | ||||
| 
 | ||||
| 	rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() | ||||
| 
 | ||||
| 	if !rows.Next() { | ||||
| 		t.Errorf("Should have returned at least one row") | ||||
| 	} else { | ||||
| 		columns, _ := rows.Columns() | ||||
| 		if !reflect.DeepEqual(columns, []string{"fake"}) { | ||||
| 			t.Errorf("Should only contains one column") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	rows.Close() | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithArrayInput(t *testing.T) { | ||||
| 	DB.Save(&User{Name: "jinzhu", Age: 42}) | ||||
| 
 | ||||
| 	var user User | ||||
| 	DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) | ||||
| 
 | ||||
| 	if user.Name != "jinzhu" || user.Age != 42 { | ||||
| 		t.Errorf("Should have selected both age and name") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPluckWithSelect(t *testing.T) { | ||||
| 	var ( | ||||
| 		user              = User{Name: "matematik7_pluck_with_select", Age: 25} | ||||
| 		combinedName      = fmt.Sprintf("%v%v", user.Name, user.Age) | ||||
| 		combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) | ||||
| 	) | ||||
| 
 | ||||
| 	if dialect := DB.Dialect().GetName(); dialect == "sqlite3" { | ||||
| 		combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Save(&user) | ||||
| 
 | ||||
| 	selectStr := combineUserAgeSQL + " as user_age" | ||||
| 	var userAges []string | ||||
| 	err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(userAges) != 1 || userAges[0] != combinedName { | ||||
| 		t.Errorf("Should correctly pluck with select, got: %s", userAges) | ||||
| 	} | ||||
| 
 | ||||
| 	selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age")) | ||||
| 	userAges = userAges[:0] | ||||
| 	err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(userAges) != 1 || userAges[0] != combinedName { | ||||
| 		t.Errorf("Should correctly pluck with select, got: %s", userAges) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										139
									
								
								scaner_test.go
									
									
									
									
									
								
							
							
						
						
									
										139
									
								
								scaner_test.go
									
									
									
									
									
								
							| @ -1,139 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func TestScannableSlices(t *testing.T) { | ||||
| 	if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { | ||||
| 		t.Errorf("Should create table with slice values correctly: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	r1 := RecordWithSlice{ | ||||
| 		Strings: ExampleStringSlice{"a", "b", "c"}, | ||||
| 		Structs: ExampleStructSlice{ | ||||
| 			{"name1", "value1"}, | ||||
| 			{"name2", "value2"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Save(&r1).Error; err != nil { | ||||
| 		t.Errorf("Should save record with slice values") | ||||
| 	} | ||||
| 
 | ||||
| 	var r2 RecordWithSlice | ||||
| 
 | ||||
| 	if err := DB.Find(&r2).Error; err != nil { | ||||
| 		t.Errorf("Should fetch record with slice values") | ||||
| 	} | ||||
| 
 | ||||
| 	if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { | ||||
| 		t.Errorf("Should have serialised and deserialised a string array") | ||||
| 	} | ||||
| 
 | ||||
| 	if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { | ||||
| 		t.Errorf("Should have serialised and deserialised a struct array") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type RecordWithSlice struct { | ||||
| 	ID      uint64 | ||||
| 	Strings ExampleStringSlice `sql:"type:text"` | ||||
| 	Structs ExampleStructSlice `sql:"type:text"` | ||||
| } | ||||
| 
 | ||||
| type ExampleStringSlice []string | ||||
| 
 | ||||
| func (l ExampleStringSlice) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(l) | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| func (l *ExampleStringSlice) Scan(input interface{}) error { | ||||
| 	switch value := input.(type) { | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), l) | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, l) | ||||
| 	default: | ||||
| 		return errors.New("not supported") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ExampleStruct struct { | ||||
| 	Name  string | ||||
| 	Value string | ||||
| } | ||||
| 
 | ||||
| type ExampleStructSlice []ExampleStruct | ||||
| 
 | ||||
| func (l ExampleStructSlice) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(l) | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| func (l *ExampleStructSlice) Scan(input interface{}) error { | ||||
| 	switch value := input.(type) { | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), l) | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, l) | ||||
| 	default: | ||||
| 		return errors.New("not supported") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ScannerDataType struct { | ||||
| 	Street string `sql:"TYPE:varchar(24)"` | ||||
| } | ||||
| 
 | ||||
| func (ScannerDataType) Value() (driver.Value, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (*ScannerDataType) Scan(input interface{}) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type ScannerDataTypeTestStruct struct { | ||||
| 	Field1          int | ||||
| 	ScannerDataType *ScannerDataType `sql:"TYPE:json"` | ||||
| } | ||||
| 
 | ||||
| type ScannerDataType2 struct { | ||||
| 	Street string `sql:"TYPE:varchar(24)"` | ||||
| } | ||||
| 
 | ||||
| func (ScannerDataType2) Value() (driver.Value, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (*ScannerDataType2) Scan(input interface{}) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type ScannerDataTypeTestStruct2 struct { | ||||
| 	Field1          int | ||||
| 	ScannerDataType *ScannerDataType2 | ||||
| } | ||||
| 
 | ||||
| func TestScannerDataType(t *testing.T) { | ||||
| 	scope := gorm.Scope{Value: &ScannerDataTypeTestStruct{}} | ||||
| 	if field, ok := scope.FieldByName("ScannerDataType"); ok { | ||||
| 		if DB.Dialect().DataTypeOf(field.StructField) != "json" { | ||||
| 			t.Errorf("data type for scanner is wrong") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	scope = gorm.Scope{Value: &ScannerDataTypeTestStruct2{}} | ||||
| 	if field, ok := scope.FieldByName("ScannerDataType"); ok { | ||||
| 		if DB.Dialect().DataTypeOf(field.StructField) != "varchar(24)" { | ||||
| 			t.Errorf("data type for scanner is wrong") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -1,93 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/hex" | ||||
| 	"math/rand" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func NameIn1And2(d *gorm.DB) *gorm.DB { | ||||
| 	return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) | ||||
| } | ||||
| 
 | ||||
| func NameIn2And3(d *gorm.DB) *gorm.DB { | ||||
| 	return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) | ||||
| } | ||||
| 
 | ||||
| func NameIn(names []string) func(d *gorm.DB) *gorm.DB { | ||||
| 	return func(d *gorm.DB) *gorm.DB { | ||||
| 		return d.Where("name in (?)", names) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestScopes(t *testing.T) { | ||||
| 	user1 := User{Name: "ScopeUser1", Age: 1} | ||||
| 	user2 := User{Name: "ScopeUser2", Age: 1} | ||||
| 	user3 := User{Name: "ScopeUser3", Age: 2} | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 
 | ||||
| 	var users1, users2, users3 []User | ||||
| 	DB.Scopes(NameIn1And2).Find(&users1) | ||||
| 	if len(users1) != 2 { | ||||
| 		t.Errorf("Should found two users's name in 1, 2") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) | ||||
| 	if len(users2) != 1 { | ||||
| 		t.Errorf("Should found one user's name is 2") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) | ||||
| 	if len(users3) != 2 { | ||||
| 		t.Errorf("Should found two users's name in 1, 3") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func randName() string { | ||||
| 	data := make([]byte, 8) | ||||
| 	rand.Read(data) | ||||
| 
 | ||||
| 	return "n-" + hex.EncodeToString(data) | ||||
| } | ||||
| 
 | ||||
| func TestValuer(t *testing.T) { | ||||
| 	name := randName() | ||||
| 
 | ||||
| 	origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} | ||||
| 	if err := DB.Save(&origUser).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when saving user, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var user2 User | ||||
| 	if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when querying user with valuer, but got %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFailedValuer(t *testing.T) { | ||||
| 	name := randName() | ||||
| 
 | ||||
| 	err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error | ||||
| 
 | ||||
| 	if err == nil { | ||||
| 		t.Errorf("There should be an error should happen when insert data") | ||||
| 	} else if !strings.HasPrefix(err.Error(), "Should not start with") { | ||||
| 		t.Errorf("The error should be returned from Valuer, but get %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDropTableWithTableOptions(t *testing.T) { | ||||
| 	type UserWithOptions struct { | ||||
| 		gorm.Model | ||||
| 	} | ||||
| 	DB.AutoMigrate(&UserWithOptions{}) | ||||
| 
 | ||||
| 	DB = DB.Set("gorm:table_options", "CHARSET=utf8") | ||||
| 	err := DB.DropTable(&UserWithOptions{}).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Table must be dropped, got error %s", err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										153
									
								
								search.go
									
									
									
									
									
								
							
							
						
						
									
										153
									
								
								search.go
									
									
									
									
									
								
							| @ -1,153 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| ) | ||||
| 
 | ||||
| type search struct { | ||||
| 	db               *DB | ||||
| 	whereConditions  []map[string]interface{} | ||||
| 	orConditions     []map[string]interface{} | ||||
| 	notConditions    []map[string]interface{} | ||||
| 	havingConditions []map[string]interface{} | ||||
| 	joinConditions   []map[string]interface{} | ||||
| 	initAttrs        []interface{} | ||||
| 	assignAttrs      []interface{} | ||||
| 	selects          map[string]interface{} | ||||
| 	omits            []string | ||||
| 	orders           []interface{} | ||||
| 	preload          []searchPreload | ||||
| 	offset           interface{} | ||||
| 	limit            interface{} | ||||
| 	group            string | ||||
| 	tableName        string | ||||
| 	raw              bool | ||||
| 	Unscoped         bool | ||||
| 	ignoreOrderQuery bool | ||||
| } | ||||
| 
 | ||||
| type searchPreload struct { | ||||
| 	schema     string | ||||
| 	conditions []interface{} | ||||
| } | ||||
| 
 | ||||
| func (s *search) clone() *search { | ||||
| 	clone := *s | ||||
| 	return &clone | ||||
| } | ||||
| 
 | ||||
| func (s *search) Where(query interface{}, values ...interface{}) *search { | ||||
| 	s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Not(query interface{}, values ...interface{}) *search { | ||||
| 	s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Or(query interface{}, values ...interface{}) *search { | ||||
| 	s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Attrs(attrs ...interface{}) *search { | ||||
| 	s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Assign(attrs ...interface{}) *search { | ||||
| 	s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Order(value interface{}, reorder ...bool) *search { | ||||
| 	if len(reorder) > 0 && reorder[0] { | ||||
| 		s.orders = []interface{}{} | ||||
| 	} | ||||
| 
 | ||||
| 	if value != nil && value != "" { | ||||
| 		s.orders = append(s.orders, value) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Select(query interface{}, args ...interface{}) *search { | ||||
| 	s.selects = map[string]interface{}{"query": query, "args": args} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Omit(columns ...string) *search { | ||||
| 	s.omits = columns | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Limit(limit interface{}) *search { | ||||
| 	s.limit = limit | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Offset(offset interface{}) *search { | ||||
| 	s.offset = offset | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Group(query string) *search { | ||||
| 	s.group = s.getInterfaceAsSQL(query) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Having(query interface{}, values ...interface{}) *search { | ||||
| 	if val, ok := query.(*SqlExpr); ok { | ||||
| 		s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) | ||||
| 	} else { | ||||
| 		s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Joins(query string, values ...interface{}) *search { | ||||
| 	s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Preload(schema string, values ...interface{}) *search { | ||||
| 	var preloads []searchPreload | ||||
| 	for _, preload := range s.preload { | ||||
| 		if preload.schema != schema { | ||||
| 			preloads = append(preloads, preload) | ||||
| 		} | ||||
| 	} | ||||
| 	preloads = append(preloads, searchPreload{schema, values}) | ||||
| 	s.preload = preloads | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Raw(b bool) *search { | ||||
| 	s.raw = b | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) unscoped() *search { | ||||
| 	s.Unscoped = true | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) Table(name string) *search { | ||||
| 	s.tableName = name | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *search) getInterfaceAsSQL(value interface{}) (str string) { | ||||
| 	switch value.(type) { | ||||
| 	case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: | ||||
| 		str = fmt.Sprintf("%v", value) | ||||
| 	default: | ||||
| 		s.db.AddError(ErrInvalidSQL) | ||||
| 	} | ||||
| 
 | ||||
| 	if str == "-1" { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| @ -1,30 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestCloneSearch(t *testing.T) { | ||||
| 	s := new(search) | ||||
| 	s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") | ||||
| 
 | ||||
| 	s1 := s.clone() | ||||
| 	s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") | ||||
| 
 | ||||
| 	if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { | ||||
| 		t.Errorf("Where should be copied") | ||||
| 	} | ||||
| 
 | ||||
| 	if reflect.DeepEqual(s.orders, s1.orders) { | ||||
| 		t.Errorf("Order should be copied") | ||||
| 	} | ||||
| 
 | ||||
| 	if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { | ||||
| 		t.Errorf("InitAttrs should be copied") | ||||
| 	} | ||||
| 
 | ||||
| 	if reflect.DeepEqual(s.Select, s1.Select) { | ||||
| 		t.Errorf("selectStr should be copied") | ||||
| 	} | ||||
| } | ||||
| @ -1,5 +0,0 @@ | ||||
| dialects=("postgres" "mysql" "mssql" "sqlite") | ||||
| 
 | ||||
| for dialect in "${dialects[@]}" ; do | ||||
|     DEBUG=false GORM_DIALECT=${dialect} go test | ||||
| done | ||||
							
								
								
									
										465
									
								
								update_test.go
									
									
									
									
									
								
							
							
						
						
									
										465
									
								
								update_test.go
									
									
									
									
									
								
							| @ -1,465 +0,0 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func TestUpdate(t *testing.T) { | ||||
| 	product1 := Product{Code: "product1code"} | ||||
| 	product2 := Product{Code: "product2code"} | ||||
| 
 | ||||
| 	DB.Save(&product1).Save(&product2).Update("code", "product2newcode") | ||||
| 
 | ||||
| 	if product2.Code != "product2newcode" { | ||||
| 		t.Errorf("Record should be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.First(&product1, product1.Id) | ||||
| 	DB.First(&product2, product2.Id) | ||||
| 	updatedAt1 := product1.UpdatedAt | ||||
| 
 | ||||
| 	if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { | ||||
| 		t.Errorf("Product1 should not be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { | ||||
| 		t.Errorf("Product2's code should be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { | ||||
| 		t.Errorf("Product2's code should be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") | ||||
| 
 | ||||
| 	var product4 Product | ||||
| 	DB.First(&product4, product1.Id) | ||||
| 	if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("updatedAt should be updated if something changed") | ||||
| 	} | ||||
| 
 | ||||
| 	if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { | ||||
| 		t.Errorf("Product1's code should be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { | ||||
| 		t.Errorf("Product should not be changed to 789") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { | ||||
| 		t.Error("No error should raise when update with CamelCase") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { | ||||
| 		t.Error("No error should raise when update_column with CamelCase") | ||||
| 	} | ||||
| 
 | ||||
| 	var products []Product | ||||
| 	DB.Find(&products) | ||||
| 	if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { | ||||
| 		t.Error("RowsAffected should be correct when do batch update") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.First(&product4, product4.Id) | ||||
| 	updatedAt4 := product4.UpdatedAt | ||||
| 	DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) | ||||
| 	var product5 Product | ||||
| 	DB.First(&product5, product4.Id) | ||||
| 	if product5.Price != product4.Price+100-50 { | ||||
| 		t.Errorf("Update with expression") | ||||
| 	} | ||||
| 	if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("Update with expression should update UpdatedAt") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { | ||||
| 	animal := Animal{Name: "Ferdinand"} | ||||
| 	DB.Save(&animal) | ||||
| 	updatedAt1 := animal.UpdatedAt | ||||
| 
 | ||||
| 	DB.Save(&animal).Update("name", "Francis") | ||||
| 
 | ||||
| 	if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("updatedAt should not be updated if nothing changed") | ||||
| 	} | ||||
| 
 | ||||
| 	var animals []Animal | ||||
| 	DB.Find(&animals) | ||||
| 	if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { | ||||
| 		t.Error("RowsAffected should be correct when do batch update") | ||||
| 	} | ||||
| 
 | ||||
| 	animal = Animal{From: "somewhere"}              // No name fields, should be filled with the default value (galeone)
 | ||||
| 	DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
 | ||||
| 	DB.First(&animal, animal.Counter) | ||||
| 	if animal.Name != "galeone" { | ||||
| 		t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) | ||||
| 	} | ||||
| 
 | ||||
| 	// When changing a field with a default value, the change must occur
 | ||||
| 	animal.Name = "amazing horse" | ||||
| 	DB.Save(&animal) | ||||
| 	DB.First(&animal, animal.Counter) | ||||
| 	if animal.Name != "amazing horse" { | ||||
| 		t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) | ||||
| 	} | ||||
| 
 | ||||
| 	// When changing a field with a default value with blank value
 | ||||
| 	animal.Name = "" | ||||
| 	DB.Save(&animal) | ||||
| 	DB.First(&animal, animal.Counter) | ||||
| 	if animal.Name != "" { | ||||
| 		t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdates(t *testing.T) { | ||||
| 	product1 := Product{Code: "product1code", Price: 10} | ||||
| 	product2 := Product{Code: "product2code", Price: 10} | ||||
| 	DB.Save(&product1).Save(&product2) | ||||
| 	DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) | ||||
| 	if product1.Code != "product1newcode" || product1.Price != 100 { | ||||
| 		t.Errorf("Record should be updated also with map") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.First(&product1, product1.Id) | ||||
| 	DB.First(&product2, product2.Id) | ||||
| 	updatedAt2 := product2.UpdatedAt | ||||
| 
 | ||||
| 	if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { | ||||
| 		t.Errorf("Product2 should not be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { | ||||
| 		t.Errorf("Product1 should be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) | ||||
| 	if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { | ||||
| 		t.Errorf("Product2's code should be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	var product4 Product | ||||
| 	DB.First(&product4, product2.Id) | ||||
| 	if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("updatedAt should be updated if something changed") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { | ||||
| 		t.Errorf("product2's code should be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	updatedAt4 := product4.UpdatedAt | ||||
| 	DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) | ||||
| 	var product5 Product | ||||
| 	DB.First(&product5, product4.Id) | ||||
| 	if product5.Price != product4.Price+100 { | ||||
| 		t.Errorf("Updates with expression") | ||||
| 	} | ||||
| 	// product4's UpdatedAt will be reset when updating
 | ||||
| 	if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("Updates with expression should update UpdatedAt") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdateColumn(t *testing.T) { | ||||
| 	product1 := Product{Code: "product1code", Price: 10} | ||||
| 	product2 := Product{Code: "product2code", Price: 20} | ||||
| 	DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) | ||||
| 	if product2.Code != "product2newcode" || product2.Price != 100 { | ||||
| 		t.Errorf("product 2 should be updated with update column") | ||||
| 	} | ||||
| 
 | ||||
| 	var product3 Product | ||||
| 	DB.First(&product3, product1.Id) | ||||
| 	if product3.Code != "product1code" || product3.Price != 10 { | ||||
| 		t.Errorf("product 1 should not be updated") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.First(&product2, product2.Id) | ||||
| 	updatedAt2 := product2.UpdatedAt | ||||
| 	DB.Model(product2).UpdateColumn("code", "update_column_new") | ||||
| 	var product4 Product | ||||
| 	DB.First(&product4, product2.Id) | ||||
| 	if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("updatedAt should not be updated with update column") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) | ||||
| 	var product5 Product | ||||
| 	DB.First(&product5, product4.Id) | ||||
| 	if product5.Price != product4.Price+100-50 { | ||||
| 		t.Errorf("UpdateColumn with expression") | ||||
| 	} | ||||
| 	if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("UpdateColumn with expression should not update UpdatedAt") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithUpdate(t *testing.T) { | ||||
| 	user := getPreparedUser("select_user", "select_with_update") | ||||
| 	DB.Create(user) | ||||
| 
 | ||||
| 	var reloadUser User | ||||
| 	DB.First(&reloadUser, user.Id) | ||||
| 	reloadUser.Name = "new_name" | ||||
| 	reloadUser.Age = 50 | ||||
| 	reloadUser.BillingAddress = Address{Address1: "New Billing Address"} | ||||
| 	reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} | ||||
| 	reloadUser.CreditCard = CreditCard{Number: "987654321"} | ||||
| 	reloadUser.Emails = []Email{ | ||||
| 		{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | ||||
| 	} | ||||
| 	reloadUser.Company = Company{Name: "new company"} | ||||
| 
 | ||||
| 	DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) | ||||
| 
 | ||||
| 	var queryUser User | ||||
| 	DB.Preload("BillingAddress").Preload("ShippingAddress"). | ||||
| 		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | ||||
| 
 | ||||
| 	if queryUser.Name == user.Name || queryUser.Age != user.Age { | ||||
| 		t.Errorf("Should only update users with name column") | ||||
| 	} | ||||
| 
 | ||||
| 	if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || | ||||
| 		queryUser.ShippingAddressId != user.ShippingAddressId || | ||||
| 		queryUser.CreditCard.ID == user.CreditCard.ID || | ||||
| 		len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { | ||||
| 		t.Errorf("Should only update selected relationships") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithUpdateWithMap(t *testing.T) { | ||||
| 	user := getPreparedUser("select_user", "select_with_update_map") | ||||
| 	DB.Create(user) | ||||
| 
 | ||||
| 	updateValues := map[string]interface{}{ | ||||
| 		"Name":            "new_name", | ||||
| 		"Age":             50, | ||||
| 		"BillingAddress":  Address{Address1: "New Billing Address"}, | ||||
| 		"ShippingAddress": Address{Address1: "New ShippingAddress Address"}, | ||||
| 		"CreditCard":      CreditCard{Number: "987654321"}, | ||||
| 		"Emails": []Email{ | ||||
| 			{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | ||||
| 		}, | ||||
| 		"Company": Company{Name: "new company"}, | ||||
| 	} | ||||
| 
 | ||||
| 	var reloadUser User | ||||
| 	DB.First(&reloadUser, user.Id) | ||||
| 	DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) | ||||
| 
 | ||||
| 	var queryUser User | ||||
| 	DB.Preload("BillingAddress").Preload("ShippingAddress"). | ||||
| 		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | ||||
| 
 | ||||
| 	if queryUser.Name == user.Name || queryUser.Age != user.Age { | ||||
| 		t.Errorf("Should only update users with name column") | ||||
| 	} | ||||
| 
 | ||||
| 	if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || | ||||
| 		queryUser.ShippingAddressId != user.ShippingAddressId || | ||||
| 		queryUser.CreditCard.ID == user.CreditCard.ID || | ||||
| 		len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { | ||||
| 		t.Errorf("Should only update selected relationships") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOmitWithUpdate(t *testing.T) { | ||||
| 	user := getPreparedUser("omit_user", "omit_with_update") | ||||
| 	DB.Create(user) | ||||
| 
 | ||||
| 	var reloadUser User | ||||
| 	DB.First(&reloadUser, user.Id) | ||||
| 	reloadUser.Name = "new_name" | ||||
| 	reloadUser.Age = 50 | ||||
| 	reloadUser.BillingAddress = Address{Address1: "New Billing Address"} | ||||
| 	reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} | ||||
| 	reloadUser.CreditCard = CreditCard{Number: "987654321"} | ||||
| 	reloadUser.Emails = []Email{ | ||||
| 		{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | ||||
| 	} | ||||
| 	reloadUser.Company = Company{Name: "new company"} | ||||
| 
 | ||||
| 	DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) | ||||
| 
 | ||||
| 	var queryUser User | ||||
| 	DB.Preload("BillingAddress").Preload("ShippingAddress"). | ||||
| 		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | ||||
| 
 | ||||
| 	if queryUser.Name != user.Name || queryUser.Age == user.Age { | ||||
| 		t.Errorf("Should only update users with name column") | ||||
| 	} | ||||
| 
 | ||||
| 	if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || | ||||
| 		queryUser.ShippingAddressId == user.ShippingAddressId || | ||||
| 		queryUser.CreditCard.ID != user.CreditCard.ID || | ||||
| 		len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { | ||||
| 		t.Errorf("Should only update relationships that not omitted") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOmitWithUpdateWithMap(t *testing.T) { | ||||
| 	user := getPreparedUser("select_user", "select_with_update_map") | ||||
| 	DB.Create(user) | ||||
| 
 | ||||
| 	updateValues := map[string]interface{}{ | ||||
| 		"Name":            "new_name", | ||||
| 		"Age":             50, | ||||
| 		"BillingAddress":  Address{Address1: "New Billing Address"}, | ||||
| 		"ShippingAddress": Address{Address1: "New ShippingAddress Address"}, | ||||
| 		"CreditCard":      CreditCard{Number: "987654321"}, | ||||
| 		"Emails": []Email{ | ||||
| 			{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, | ||||
| 		}, | ||||
| 		"Company": Company{Name: "new company"}, | ||||
| 	} | ||||
| 
 | ||||
| 	var reloadUser User | ||||
| 	DB.First(&reloadUser, user.Id) | ||||
| 	DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) | ||||
| 
 | ||||
| 	var queryUser User | ||||
| 	DB.Preload("BillingAddress").Preload("ShippingAddress"). | ||||
| 		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) | ||||
| 
 | ||||
| 	if queryUser.Name != user.Name || queryUser.Age == user.Age { | ||||
| 		t.Errorf("Should only update users with name column") | ||||
| 	} | ||||
| 
 | ||||
| 	if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || | ||||
| 		queryUser.ShippingAddressId == user.ShippingAddressId || | ||||
| 		queryUser.CreditCard.ID != user.CreditCard.ID || | ||||
| 		len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { | ||||
| 		t.Errorf("Should only update relationships not omitted") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithUpdateColumn(t *testing.T) { | ||||
| 	user := getPreparedUser("select_user", "select_with_update_map") | ||||
| 	DB.Create(user) | ||||
| 
 | ||||
| 	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} | ||||
| 
 | ||||
| 	var reloadUser User | ||||
| 	DB.First(&reloadUser, user.Id) | ||||
| 	DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) | ||||
| 
 | ||||
| 	var queryUser User | ||||
| 	DB.First(&queryUser, user.Id) | ||||
| 
 | ||||
| 	if queryUser.Name == user.Name || queryUser.Age != user.Age { | ||||
| 		t.Errorf("Should only update users with name column") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOmitWithUpdateColumn(t *testing.T) { | ||||
| 	user := getPreparedUser("select_user", "select_with_update_map") | ||||
| 	DB.Create(user) | ||||
| 
 | ||||
| 	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} | ||||
| 
 | ||||
| 	var reloadUser User | ||||
| 	DB.First(&reloadUser, user.Id) | ||||
| 	DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) | ||||
| 
 | ||||
| 	var queryUser User | ||||
| 	DB.First(&queryUser, user.Id) | ||||
| 
 | ||||
| 	if queryUser.Name != user.Name || queryUser.Age == user.Age { | ||||
| 		t.Errorf("Should omit name column when update user") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdateColumnsSkipsAssociations(t *testing.T) { | ||||
| 	user := getPreparedUser("update_columns_user", "special_role") | ||||
| 	user.Age = 99 | ||||
| 	address1 := "first street" | ||||
| 	user.BillingAddress = Address{Address1: address1} | ||||
| 	DB.Save(user) | ||||
| 
 | ||||
| 	// Update a single field of the user and verify that the changed address is not stored.
 | ||||
| 	newAge := int64(100) | ||||
| 	user.BillingAddress.Address1 = "second street" | ||||
| 	db := DB.Model(user).UpdateColumns(User{Age: newAge}) | ||||
| 	if db.RowsAffected != 1 { | ||||
| 		t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) | ||||
| 	} | ||||
| 
 | ||||
| 	// Verify that Age now=`newAge`.
 | ||||
| 	freshUser := &User{Id: user.Id} | ||||
| 	DB.First(freshUser) | ||||
| 	if freshUser.Age != newAge { | ||||
| 		t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) | ||||
| 	} | ||||
| 
 | ||||
| 	// Verify that user's BillingAddress.Address1 is not changed and is still "first street".
 | ||||
| 	DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) | ||||
| 	if freshUser.BillingAddress.Address1 != address1 { | ||||
| 		t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdatesWithBlankValues(t *testing.T) { | ||||
| 	product := Product{Code: "product1", Price: 10} | ||||
| 	DB.Save(&product) | ||||
| 
 | ||||
| 	DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) | ||||
| 
 | ||||
| 	var product1 Product | ||||
| 	DB.First(&product1, product.Id) | ||||
| 
 | ||||
| 	if product1.Code != "product1" || product1.Price != 100 { | ||||
| 		t.Errorf("product's code should not be updated") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ElementWithIgnoredField struct { | ||||
| 	Id           int64 | ||||
| 	Value        string | ||||
| 	IgnoredField int64 `sql:"-"` | ||||
| } | ||||
| 
 | ||||
| func (e ElementWithIgnoredField) TableName() string { | ||||
| 	return "element_with_ignored_field" | ||||
| } | ||||
| 
 | ||||
| func TestUpdatesTableWithIgnoredValues(t *testing.T) { | ||||
| 	elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} | ||||
| 	DB.Save(&elem) | ||||
| 
 | ||||
| 	DB.Table(elem.TableName()). | ||||
| 		Where("id = ?", elem.Id). | ||||
| 		// DB.Model(&ElementWithIgnoredField{Id: elem.Id}).
 | ||||
| 		Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) | ||||
| 
 | ||||
| 	var elem1 ElementWithIgnoredField | ||||
| 	err := DB.First(&elem1, elem.Id).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("error getting an element from database: %s", err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	if elem1.IgnoredField != 0 { | ||||
| 		t.Errorf("element's ignored field should not be updated") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdateDecodeVirtualAttributes(t *testing.T) { | ||||
| 	var user = User{ | ||||
| 		Name:     "jinzhu", | ||||
| 		IgnoreMe: 88, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Save(&user) | ||||
| 
 | ||||
| 	DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) | ||||
| 
 | ||||
| 	if user.IgnoreMe != 100 { | ||||
| 		t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										226
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										226
									
								
								utils.go
									
									
									
									
									
								
							| @ -1,226 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // NowFunc returns current time, this function is exported in order to be able
 | ||||
| // to give the flexibility to the developer to customize it according to their
 | ||||
| // needs, e.g:
 | ||||
| //    gorm.NowFunc = func() time.Time {
 | ||||
| //      return time.Now().UTC()
 | ||||
| //    }
 | ||||
| var NowFunc = func() time.Time { | ||||
| 	return time.Now() | ||||
| } | ||||
| 
 | ||||
| // Copied from golint
 | ||||
| var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} | ||||
| var commonInitialismsReplacer *strings.Replacer | ||||
| 
 | ||||
| var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) | ||||
| var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) | ||||
| 
 | ||||
| func init() { | ||||
| 	var commonInitialismsForReplacer []string | ||||
| 	for _, initialism := range commonInitialisms { | ||||
| 		commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) | ||||
| 	} | ||||
| 	commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) | ||||
| } | ||||
| 
 | ||||
| type safeMap struct { | ||||
| 	m map[string]string | ||||
| 	l *sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| func (s *safeMap) Set(key string, value string) { | ||||
| 	s.l.Lock() | ||||
| 	defer s.l.Unlock() | ||||
| 	s.m[key] = value | ||||
| } | ||||
| 
 | ||||
| func (s *safeMap) Get(key string) string { | ||||
| 	s.l.RLock() | ||||
| 	defer s.l.RUnlock() | ||||
| 	return s.m[key] | ||||
| } | ||||
| 
 | ||||
| func newSafeMap() *safeMap { | ||||
| 	return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} | ||||
| } | ||||
| 
 | ||||
| // SQL expression
 | ||||
| type SqlExpr struct { | ||||
| 	expr string | ||||
| 	args []interface{} | ||||
| } | ||||
| 
 | ||||
| // Expr generate raw SQL expression, for example:
 | ||||
| //     DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
 | ||||
| func Expr(expression string, args ...interface{}) *SqlExpr { | ||||
| 	return &SqlExpr{expr: expression, args: args} | ||||
| } | ||||
| 
 | ||||
| func indirect(reflectValue reflect.Value) reflect.Value { | ||||
| 	for reflectValue.Kind() == reflect.Ptr { | ||||
| 		reflectValue = reflectValue.Elem() | ||||
| 	} | ||||
| 	return reflectValue | ||||
| } | ||||
| 
 | ||||
| func toQueryMarks(primaryValues [][]interface{}) string { | ||||
| 	var results []string | ||||
| 
 | ||||
| 	for _, primaryValue := range primaryValues { | ||||
| 		var marks []string | ||||
| 		for range primaryValue { | ||||
| 			marks = append(marks, "?") | ||||
| 		} | ||||
| 
 | ||||
| 		if len(marks) > 1 { | ||||
| 			results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) | ||||
| 		} else { | ||||
| 			results = append(results, strings.Join(marks, "")) | ||||
| 		} | ||||
| 	} | ||||
| 	return strings.Join(results, ",") | ||||
| } | ||||
| 
 | ||||
| func toQueryCondition(scope *Scope, columns []string) string { | ||||
| 	var newColumns []string | ||||
| 	for _, column := range columns { | ||||
| 		newColumns = append(newColumns, scope.Quote(column)) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(columns) > 1 { | ||||
| 		return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) | ||||
| 	} | ||||
| 	return strings.Join(newColumns, ",") | ||||
| } | ||||
| 
 | ||||
| func toQueryValues(values [][]interface{}) (results []interface{}) { | ||||
| 	for _, value := range values { | ||||
| 		for _, v := range value { | ||||
| 			results = append(results, v) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func fileWithLineNum() string { | ||||
| 	for i := 2; i < 15; i++ { | ||||
| 		_, file, line, ok := runtime.Caller(i) | ||||
| 		if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { | ||||
| 			return fmt.Sprintf("%v:%v", file, line) | ||||
| 		} | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func isBlank(value reflect.Value) bool { | ||||
| 	switch value.Kind() { | ||||
| 	case reflect.String: | ||||
| 		return value.Len() == 0 | ||||
| 	case reflect.Bool: | ||||
| 		return !value.Bool() | ||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 		return value.Int() == 0 | ||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | ||||
| 		return value.Uint() == 0 | ||||
| 	case reflect.Float32, reflect.Float64: | ||||
| 		return value.Float() == 0 | ||||
| 	case reflect.Interface, reflect.Ptr: | ||||
| 		return value.IsNil() | ||||
| 	} | ||||
| 
 | ||||
| 	return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) | ||||
| } | ||||
| 
 | ||||
| func toSearchableMap(attrs ...interface{}) (result interface{}) { | ||||
| 	if len(attrs) > 1 { | ||||
| 		if str, ok := attrs[0].(string); ok { | ||||
| 			result = map[string]interface{}{str: attrs[1]} | ||||
| 		} | ||||
| 	} else if len(attrs) == 1 { | ||||
| 		if attr, ok := attrs[0].(map[string]interface{}); ok { | ||||
| 			result = attr | ||||
| 		} | ||||
| 
 | ||||
| 		if attr, ok := attrs[0].(interface{}); ok { | ||||
| 			result = attr | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func equalAsString(a interface{}, b interface{}) bool { | ||||
| 	return toString(a) == toString(b) | ||||
| } | ||||
| 
 | ||||
| func toString(str interface{}) string { | ||||
| 	if values, ok := str.([]interface{}); ok { | ||||
| 		var results []string | ||||
| 		for _, value := range values { | ||||
| 			results = append(results, toString(value)) | ||||
| 		} | ||||
| 		return strings.Join(results, "_") | ||||
| 	} else if bytes, ok := str.([]byte); ok { | ||||
| 		return string(bytes) | ||||
| 	} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { | ||||
| 		return fmt.Sprintf("%v", reflectValue.Interface()) | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func makeSlice(elemType reflect.Type) interface{} { | ||||
| 	if elemType.Kind() == reflect.Slice { | ||||
| 		elemType = elemType.Elem() | ||||
| 	} | ||||
| 	sliceType := reflect.SliceOf(elemType) | ||||
| 	slice := reflect.New(sliceType) | ||||
| 	slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) | ||||
| 	return slice.Interface() | ||||
| } | ||||
| 
 | ||||
| func strInSlice(a string, list []string) bool { | ||||
| 	for _, b := range list { | ||||
| 		if b == a { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // getValueFromFields return given fields's value
 | ||||
| func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { | ||||
| 	// If value is a nil pointer, Indirect returns a zero Value!
 | ||||
| 	// Therefor we need to check for a zero value,
 | ||||
| 	// as FieldByName could panic
 | ||||
| 	if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { | ||||
| 		for _, fieldName := range fieldNames { | ||||
| 			if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { | ||||
| 				result := fieldValue.Interface() | ||||
| 				if r, ok := result.(driver.Valuer); ok { | ||||
| 					result, _ = r.Value() | ||||
| 				} | ||||
| 				results = append(results, result) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func addExtraSpaceIfExist(str string) string { | ||||
| 	if str != "" { | ||||
| 		return " " + str | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
							
								
								
									
										154
									
								
								wercker.yml
									
									
									
									
									
								
							
							
						
						
									
										154
									
								
								wercker.yml
									
									
									
									
									
								
							| @ -1,154 +0,0 @@ | ||||
| # use the default golang container from Docker Hub | ||||
| box: golang | ||||
| 
 | ||||
| services: | ||||
|     - name: mariadb | ||||
|       id: mariadb:latest | ||||
|       env: | ||||
|           MYSQL_DATABASE: gorm | ||||
|           MYSQL_USER: gorm | ||||
|           MYSQL_PASSWORD: gorm | ||||
|           MYSQL_RANDOM_ROOT_PASSWORD: "yes" | ||||
|     - name: mysql | ||||
|       id: mysql:latest | ||||
|       env: | ||||
|           MYSQL_DATABASE: gorm | ||||
|           MYSQL_USER: gorm | ||||
|           MYSQL_PASSWORD: gorm | ||||
|           MYSQL_RANDOM_ROOT_PASSWORD: "yes" | ||||
|     - name: mysql57 | ||||
|       id: mysql:5.7 | ||||
|       env: | ||||
|           MYSQL_DATABASE: gorm | ||||
|           MYSQL_USER: gorm | ||||
|           MYSQL_PASSWORD: gorm | ||||
|           MYSQL_RANDOM_ROOT_PASSWORD: "yes" | ||||
|     - name: mysql56 | ||||
|       id: mysql:5.6 | ||||
|       env: | ||||
|           MYSQL_DATABASE: gorm | ||||
|           MYSQL_USER: gorm | ||||
|           MYSQL_PASSWORD: gorm | ||||
|           MYSQL_RANDOM_ROOT_PASSWORD: "yes" | ||||
|     - name: postgres | ||||
|       id: postgres:latest | ||||
|       env: | ||||
|           POSTGRES_USER: gorm | ||||
|           POSTGRES_PASSWORD: gorm | ||||
|           POSTGRES_DB: gorm | ||||
|     - name: postgres96 | ||||
|       id: postgres:9.6 | ||||
|       env: | ||||
|           POSTGRES_USER: gorm | ||||
|           POSTGRES_PASSWORD: gorm | ||||
|           POSTGRES_DB: gorm | ||||
|     - name: postgres95 | ||||
|       id: postgres:9.5 | ||||
|       env: | ||||
|           POSTGRES_USER: gorm | ||||
|           POSTGRES_PASSWORD: gorm | ||||
|           POSTGRES_DB: gorm | ||||
|     - name: postgres94 | ||||
|       id: postgres:9.4 | ||||
|       env: | ||||
|           POSTGRES_USER: gorm | ||||
|           POSTGRES_PASSWORD: gorm | ||||
|           POSTGRES_DB: gorm | ||||
|     - name: postgres93 | ||||
|       id: postgres:9.3 | ||||
|       env: | ||||
|           POSTGRES_USER: gorm | ||||
|           POSTGRES_PASSWORD: gorm | ||||
|           POSTGRES_DB: gorm | ||||
|     - name: mssql | ||||
|       id: mcmoe/mssqldocker:latest | ||||
|       env: | ||||
|         ACCEPT_EULA: Y | ||||
|         SA_PASSWORD: LoremIpsum86 | ||||
|         MSSQL_DB: gorm | ||||
|         MSSQL_USER: gorm | ||||
|         MSSQL_PASSWORD: LoremIpsum86 | ||||
| 
 | ||||
| # The steps that will be executed in the build pipeline | ||||
| build: | ||||
|     # The steps that will be executed on build | ||||
|     steps: | ||||
|         # Sets the go workspace and places you package | ||||
|         # at the right place in the workspace tree | ||||
|         - setup-go-workspace | ||||
| 
 | ||||
|         # Gets the dependencies | ||||
|         - script: | ||||
|                 name: go get | ||||
|                 code: | | ||||
|                     cd $WERCKER_SOURCE_DIR | ||||
|                     go version | ||||
|                     go get -t -v ./... | ||||
| 
 | ||||
|         # Build the project | ||||
|         - script: | ||||
|                 name: go build | ||||
|                 code: | | ||||
|                     go build ./... | ||||
| 
 | ||||
|         # Test the project | ||||
|         - script: | ||||
|                 name: test sqlite | ||||
|                 code: | | ||||
|                     go test -race -v ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test mariadb | ||||
|                 code: | | ||||
|                     GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test mysql | ||||
|                 code: | | ||||
|                     GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test mysql5.7 | ||||
|                 code: | | ||||
|                     GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test mysql5.6 | ||||
|                 code: | | ||||
|                     GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test postgres | ||||
|                 code: | | ||||
|                     GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test postgres96 | ||||
|                 code: | | ||||
|                     GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test postgres95 | ||||
|                 code: | | ||||
|                     GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test postgres94 | ||||
|                 code: | | ||||
|                     GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test postgres93 | ||||
|                 code: | | ||||
|                     GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: test mssql | ||||
|                 code: | | ||||
|                     GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./... | ||||
| 
 | ||||
|         - script: | ||||
|                 name: codecov | ||||
|                 code: | | ||||
|                     go test -race -coverprofile=coverage.txt -covermode=atomic ./... | ||||
|                     bash <(curl -s https://codecov.io/bash) | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu