commit
						9d57c6b961
					
				
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,2 @@
 | 
				
			|||||||
 | 
					documents
 | 
				
			||||||
 | 
					_book
 | 
				
			||||||
							
								
								
									
										598
									
								
								association.go
									
									
									
									
									
								
							
							
						
						
									
										598
									
								
								association.go
									
									
									
									
									
								
							@ -4,32 +4,289 @@ import (
 | 
				
			|||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Association Mode contains some helper methods to handle relationship things easily.
 | 
				
			||||||
type Association struct {
 | 
					type Association struct {
 | 
				
			||||||
	Scope  *Scope
 | 
					 | 
				
			||||||
	Column string
 | 
					 | 
				
			||||||
	Error  error
 | 
						Error  error
 | 
				
			||||||
	Field  *Field
 | 
						scope  *Scope
 | 
				
			||||||
 | 
						column string
 | 
				
			||||||
 | 
						field  *Field
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (association *Association) setErr(err error) *Association {
 | 
					// Find find out all related associations
 | 
				
			||||||
	if err != nil {
 | 
					func (association *Association) Find(value interface{}) *Association {
 | 
				
			||||||
		association.Error = err
 | 
						association.scope.related(value, association.column)
 | 
				
			||||||
 | 
						return association.setErr(association.scope.db.Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
 | 
				
			||||||
 | 
					func (association *Association) Append(values ...interface{}) *Association {
 | 
				
			||||||
 | 
						if relationship := association.field.Relationship; relationship.Kind == "has_one" {
 | 
				
			||||||
 | 
							return association.Replace(values...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return association.saveAssociations(values...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Replace replace current associations with new one
 | 
				
			||||||
 | 
					func (association *Association) Replace(values ...interface{}) *Association {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							relationship = association.field.Relationship
 | 
				
			||||||
 | 
							scope        = association.scope
 | 
				
			||||||
 | 
							field        = association.field.Field
 | 
				
			||||||
 | 
							newDB        = scope.NewDB()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Append new values
 | 
				
			||||||
 | 
						association.field.Set(reflect.Zero(association.field.Field.Type()))
 | 
				
			||||||
 | 
						association.saveAssociations(values...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Belongs To
 | 
				
			||||||
 | 
						if relationship.Kind == "belongs_to" {
 | 
				
			||||||
 | 
							// Set foreign key to be null when clearing value (length equals 0)
 | 
				
			||||||
 | 
							if len(values) == 0 {
 | 
				
			||||||
 | 
								// Set foreign key to be nil
 | 
				
			||||||
 | 
								var foreignKeyMap = map[string]interface{}{}
 | 
				
			||||||
 | 
								for _, foreignKey := range relationship.ForeignDBNames {
 | 
				
			||||||
 | 
									foreignKeyMap[foreignKey] = nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							// Polymorphic Relations
 | 
				
			||||||
 | 
							if relationship.PolymorphicDBName != "" {
 | 
				
			||||||
 | 
								newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Delete Relations except new created
 | 
				
			||||||
 | 
							if len(values) > 0 {
 | 
				
			||||||
 | 
								var associationForeignFieldNames []string
 | 
				
			||||||
 | 
								if relationship.Kind == "many_to_many" {
 | 
				
			||||||
 | 
									// if many to many relations, get association fields name from association foreign keys
 | 
				
			||||||
 | 
									associationScope := scope.New(reflect.New(field.Type()).Interface())
 | 
				
			||||||
 | 
									for _, dbName := range relationship.AssociationForeignFieldNames {
 | 
				
			||||||
 | 
										if field, ok := associationScope.FieldByName(dbName); ok {
 | 
				
			||||||
 | 
											associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									// If other relations, use primary keys
 | 
				
			||||||
 | 
									for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
 | 
				
			||||||
 | 
										associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if len(newPrimaryKeys) > 0 {
 | 
				
			||||||
 | 
									sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
 | 
				
			||||||
 | 
									newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if relationship.Kind == "many_to_many" {
 | 
				
			||||||
 | 
								// if many to many relations, delete related relations from join table
 | 
				
			||||||
 | 
								var sourceForeignFieldNames []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for _, dbName := range relationship.ForeignFieldNames {
 | 
				
			||||||
 | 
									if field, ok := scope.FieldByName(dbName); ok {
 | 
				
			||||||
 | 
										sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
 | 
				
			||||||
 | 
									newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
 | 
				
			||||||
 | 
								// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
 | 
				
			||||||
 | 
								var foreignKeyMap = map[string]interface{}{}
 | 
				
			||||||
 | 
								for idx, foreignKey := range relationship.ForeignDBNames {
 | 
				
			||||||
 | 
									foreignKeyMap[foreignKey] = nil
 | 
				
			||||||
 | 
									if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
 | 
				
			||||||
 | 
										newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								fieldValue := reflect.New(association.field.Field.Type()).Interface()
 | 
				
			||||||
 | 
								association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return association
 | 
						return association
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (association *Association) Find(value interface{}) *Association {
 | 
					// Delete remove relationship between source & passed arguments, but won't delete those arguments
 | 
				
			||||||
	association.Scope.related(value, association.Column)
 | 
					func (association *Association) Delete(values ...interface{}) *Association {
 | 
				
			||||||
	return association.setErr(association.Scope.db.Error)
 | 
						var (
 | 
				
			||||||
 | 
							relationship = association.field.Relationship
 | 
				
			||||||
 | 
							scope        = association.scope
 | 
				
			||||||
 | 
							field        = association.field.Field
 | 
				
			||||||
 | 
							newDB        = scope.NewDB()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(values) == 0 {
 | 
				
			||||||
 | 
							return association
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
 | 
				
			||||||
 | 
						for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
 | 
				
			||||||
 | 
							deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
 | 
				
			||||||
 | 
							deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relationship.Kind == "many_to_many" {
 | 
				
			||||||
 | 
							// source value's foreign keys
 | 
				
			||||||
 | 
							for idx, foreignKey := range relationship.ForeignDBNames {
 | 
				
			||||||
 | 
								if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
 | 
				
			||||||
 | 
									newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// get association's foreign fields name
 | 
				
			||||||
 | 
							var associationScope = scope.New(reflect.New(field.Type()).Interface())
 | 
				
			||||||
 | 
							var associationForeignFieldNames []string
 | 
				
			||||||
 | 
							for _, associationDBName := range relationship.AssociationForeignFieldNames {
 | 
				
			||||||
 | 
								if field, ok := associationScope.FieldByName(associationDBName); ok {
 | 
				
			||||||
 | 
									associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// association value's foreign keys
 | 
				
			||||||
 | 
							deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
 | 
				
			||||||
 | 
							sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
 | 
				
			||||||
 | 
							newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							var foreignKeyMap = map[string]interface{}{}
 | 
				
			||||||
 | 
							for _, foreignKey := range relationship.ForeignDBNames {
 | 
				
			||||||
 | 
								foreignKeyMap[foreignKey] = nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if relationship.Kind == "belongs_to" {
 | 
				
			||||||
 | 
								// find with deleting relation's foreign keys
 | 
				
			||||||
 | 
								primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
 | 
				
			||||||
 | 
								newDB = newDB.Where(
 | 
				
			||||||
 | 
									fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
 | 
				
			||||||
 | 
									toQueryValues(primaryKeys)...,
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// set foreign key to be null if there are some records affected
 | 
				
			||||||
 | 
								modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
 | 
				
			||||||
 | 
								if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
 | 
				
			||||||
 | 
									if results.RowsAffected > 0 {
 | 
				
			||||||
 | 
										scope.updatedAttrsWithValues(foreignKeyMap)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									association.setErr(results.Error)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
 | 
				
			||||||
 | 
								// find all relations
 | 
				
			||||||
 | 
								primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
 | 
				
			||||||
 | 
								newDB = newDB.Where(
 | 
				
			||||||
 | 
									fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
 | 
				
			||||||
 | 
									toQueryValues(primaryKeys)...,
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// only include those deleting relations
 | 
				
			||||||
 | 
								newDB = newDB.Where(
 | 
				
			||||||
 | 
									fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
 | 
				
			||||||
 | 
									toQueryValues(deletingPrimaryKeys)...,
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// set matched relation's foreign key to be null
 | 
				
			||||||
 | 
								fieldValue := reflect.New(association.field.Field.Type()).Interface()
 | 
				
			||||||
 | 
								association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Remove deleted records from source's field
 | 
				
			||||||
 | 
						if association.Error == nil {
 | 
				
			||||||
 | 
							if field.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
								leftValues := reflect.Zero(field.Type())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for i := 0; i < field.Len(); i++ {
 | 
				
			||||||
 | 
									reflectValue := field.Index(i)
 | 
				
			||||||
 | 
									primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
 | 
				
			||||||
 | 
									var isDeleted = false
 | 
				
			||||||
 | 
									for _, pk := range deletingPrimaryKeys {
 | 
				
			||||||
 | 
										if equalAsString(primaryKey, pk) {
 | 
				
			||||||
 | 
											isDeleted = true
 | 
				
			||||||
 | 
											break
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if !isDeleted {
 | 
				
			||||||
 | 
										leftValues = reflect.Append(leftValues, reflectValue)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								association.field.Set(leftValues)
 | 
				
			||||||
 | 
							} else if field.Kind() == reflect.Struct {
 | 
				
			||||||
 | 
								primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
 | 
				
			||||||
 | 
								for _, pk := range deletingPrimaryKeys {
 | 
				
			||||||
 | 
									if equalAsString(primaryKey, pk) {
 | 
				
			||||||
 | 
										association.field.Set(reflect.Zero(field.Type()))
 | 
				
			||||||
 | 
										break
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return association
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Clear remove relationship between source & current associations, won't delete those associations
 | 
				
			||||||
 | 
					func (association *Association) Clear() *Association {
 | 
				
			||||||
 | 
						return association.Replace()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Count return the count of current associations
 | 
				
			||||||
 | 
					func (association *Association) Count() int {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							count        = 0
 | 
				
			||||||
 | 
							relationship = association.field.Relationship
 | 
				
			||||||
 | 
							scope        = association.scope
 | 
				
			||||||
 | 
							fieldValue   = association.field.Field.Interface()
 | 
				
			||||||
 | 
							query        = scope.DB()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relationship.Kind == "many_to_many" {
 | 
				
			||||||
 | 
							query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
 | 
				
			||||||
 | 
						} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
 | 
				
			||||||
 | 
							primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
 | 
				
			||||||
 | 
							query = query.Where(
 | 
				
			||||||
 | 
								fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
 | 
				
			||||||
 | 
								toQueryValues(primaryKeys)...,
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
						} else if relationship.Kind == "belongs_to" {
 | 
				
			||||||
 | 
							primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
 | 
				
			||||||
 | 
							query = query.Where(
 | 
				
			||||||
 | 
								fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
 | 
				
			||||||
 | 
								toQueryValues(primaryKeys)...,
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relationship.PolymorphicType != "" {
 | 
				
			||||||
 | 
							query = query.Where(
 | 
				
			||||||
 | 
								fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
 | 
				
			||||||
 | 
								scope.TableName(),
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						query.Model(fieldValue).Count(&count)
 | 
				
			||||||
 | 
						return count
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// saveAssociations save passed values as associations
 | 
				
			||||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
 | 
					func (association *Association) saveAssociations(values ...interface{}) *Association {
 | 
				
			||||||
	scope := association.Scope
 | 
						var (
 | 
				
			||||||
	field := association.Field
 | 
							scope        = association.scope
 | 
				
			||||||
	relationship := association.Field.Relationship
 | 
							field        = association.field
 | 
				
			||||||
 | 
							relationship = field.Relationship
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	saveAssociation := func(reflectValue reflect.Value) {
 | 
						saveAssociation := func(reflectValue reflect.Value) {
 | 
				
			||||||
		// value has to been pointer
 | 
							// value has to been pointer
 | 
				
			||||||
@ -94,318 +351,9 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
 | 
				
			|||||||
	return association
 | 
						return association
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (association *Association) Append(values ...interface{}) *Association {
 | 
					func (association *Association) setErr(err error) *Association {
 | 
				
			||||||
	if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
 | 
						if err != nil {
 | 
				
			||||||
		return association.Replace(values...)
 | 
							association.Error = err
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return association.saveAssociations(values...)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (association *Association) Replace(values ...interface{}) *Association {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		relationship = association.Field.Relationship
 | 
					 | 
				
			||||||
		scope        = association.Scope
 | 
					 | 
				
			||||||
		field        = association.Field.Field
 | 
					 | 
				
			||||||
		newDB        = scope.NewDB()
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Append new values
 | 
					 | 
				
			||||||
	association.Field.Set(reflect.Zero(association.Field.Field.Type()))
 | 
					 | 
				
			||||||
	association.saveAssociations(values...)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Belongs To
 | 
					 | 
				
			||||||
	if relationship.Kind == "belongs_to" {
 | 
					 | 
				
			||||||
		// Set foreign key to be null only when clearing value
 | 
					 | 
				
			||||||
		if len(values) == 0 {
 | 
					 | 
				
			||||||
			// Set foreign key to be nil
 | 
					 | 
				
			||||||
			var foreignKeyMap = map[string]interface{}{}
 | 
					 | 
				
			||||||
			for _, foreignKey := range relationship.ForeignDBNames {
 | 
					 | 
				
			||||||
				foreignKeyMap[foreignKey] = nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		// Relations
 | 
					 | 
				
			||||||
		if relationship.PolymorphicDBName != "" {
 | 
					 | 
				
			||||||
			newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Relations except new created
 | 
					 | 
				
			||||||
		if len(values) > 0 {
 | 
					 | 
				
			||||||
			var newPrimaryKeys [][]interface{}
 | 
					 | 
				
			||||||
			var associationForeignFieldNames []string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if relationship.Kind == "many_to_many" {
 | 
					 | 
				
			||||||
				// If many to many relations, get it from foreign key
 | 
					 | 
				
			||||||
				associationForeignFieldNames = relationship.AssociationForeignFieldNames
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				// If other relations, get real primary keys
 | 
					 | 
				
			||||||
				for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
 | 
					 | 
				
			||||||
					if field.IsPrimaryKey {
 | 
					 | 
				
			||||||
						associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if len(newPrimaryKeys) > 0 {
 | 
					 | 
				
			||||||
				sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
 | 
					 | 
				
			||||||
				newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if relationship.Kind == "many_to_many" {
 | 
					 | 
				
			||||||
			for idx, foreignKey := range relationship.ForeignDBNames {
 | 
					 | 
				
			||||||
				if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
 | 
					 | 
				
			||||||
					newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
 | 
					 | 
				
			||||||
		} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
 | 
					 | 
				
			||||||
			var foreignKeyMap = map[string]interface{}{}
 | 
					 | 
				
			||||||
			for idx, foreignKey := range relationship.ForeignDBNames {
 | 
					 | 
				
			||||||
				foreignKeyMap[foreignKey] = nil
 | 
					 | 
				
			||||||
				if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
 | 
					 | 
				
			||||||
					newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			fieldValue := reflect.New(association.Field.Field.Type()).Interface()
 | 
					 | 
				
			||||||
			association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return association
 | 
						return association
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (association *Association) Delete(values ...interface{}) *Association {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		relationship = association.Field.Relationship
 | 
					 | 
				
			||||||
		scope        = association.Scope
 | 
					 | 
				
			||||||
		field        = association.Field.Field
 | 
					 | 
				
			||||||
		newDB        = scope.NewDB()
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(values) == 0 {
 | 
					 | 
				
			||||||
		return association
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
 | 
					 | 
				
			||||||
	for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
 | 
					 | 
				
			||||||
		if field.IsPrimaryKey {
 | 
					 | 
				
			||||||
			deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
 | 
					 | 
				
			||||||
			deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if relationship.Kind == "many_to_many" {
 | 
					 | 
				
			||||||
		// source value's foreign keys
 | 
					 | 
				
			||||||
		for idx, foreignKey := range relationship.ForeignDBNames {
 | 
					 | 
				
			||||||
			if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
 | 
					 | 
				
			||||||
				newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// association value's foreign keys
 | 
					 | 
				
			||||||
		deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
 | 
					 | 
				
			||||||
		sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
 | 
					 | 
				
			||||||
		newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		var foreignKeyMap = map[string]interface{}{}
 | 
					 | 
				
			||||||
		for _, foreignKey := range relationship.ForeignDBNames {
 | 
					 | 
				
			||||||
			foreignKeyMap[foreignKey] = nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if relationship.Kind == "belongs_to" {
 | 
					 | 
				
			||||||
			// find with deleting relation's foreign keys
 | 
					 | 
				
			||||||
			primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
 | 
					 | 
				
			||||||
			newDB = newDB.Where(
 | 
					 | 
				
			||||||
				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
 | 
					 | 
				
			||||||
				toQueryValues(primaryKeys)...,
 | 
					 | 
				
			||||||
			)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// set foreign key to be null
 | 
					 | 
				
			||||||
			modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
 | 
					 | 
				
			||||||
			if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
 | 
					 | 
				
			||||||
				if results.RowsAffected > 0 {
 | 
					 | 
				
			||||||
					scope.updatedAttrsWithValues(foreignKeyMap, false)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				association.setErr(results.Error)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
 | 
					 | 
				
			||||||
			// find all relations
 | 
					 | 
				
			||||||
			primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value)
 | 
					 | 
				
			||||||
			newDB = newDB.Where(
 | 
					 | 
				
			||||||
				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
 | 
					 | 
				
			||||||
				toQueryValues(primaryKeys)...,
 | 
					 | 
				
			||||||
			)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// only include those deleting relations
 | 
					 | 
				
			||||||
			newDB = newDB.Where(
 | 
					 | 
				
			||||||
				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
 | 
					 | 
				
			||||||
				toQueryValues(deletingPrimaryKeys)...,
 | 
					 | 
				
			||||||
			)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// set matched relation's foreign key to be null
 | 
					 | 
				
			||||||
			fieldValue := reflect.New(association.Field.Field.Type()).Interface()
 | 
					 | 
				
			||||||
			association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Remove deleted records from field
 | 
					 | 
				
			||||||
	if association.Error == nil {
 | 
					 | 
				
			||||||
		if association.Field.Field.Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
			leftValues := reflect.Zero(association.Field.Field.Type())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for i := 0; i < association.Field.Field.Len(); i++ {
 | 
					 | 
				
			||||||
				reflectValue := association.Field.Field.Index(i)
 | 
					 | 
				
			||||||
				primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
 | 
					 | 
				
			||||||
				var included = false
 | 
					 | 
				
			||||||
				for _, pk := range deletingPrimaryKeys {
 | 
					 | 
				
			||||||
					if equalAsString(primaryKey, pk) {
 | 
					 | 
				
			||||||
						included = true
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if !included {
 | 
					 | 
				
			||||||
					leftValues = reflect.Append(leftValues, reflectValue)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			association.Field.Set(leftValues)
 | 
					 | 
				
			||||||
		} else if association.Field.Field.Kind() == reflect.Struct {
 | 
					 | 
				
			||||||
			primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
 | 
					 | 
				
			||||||
			for _, pk := range deletingPrimaryKeys {
 | 
					 | 
				
			||||||
				if equalAsString(primaryKey, pk) {
 | 
					 | 
				
			||||||
					association.Field.Set(reflect.Zero(association.Field.Field.Type()))
 | 
					 | 
				
			||||||
					break
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return association
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (association *Association) Clear() *Association {
 | 
					 | 
				
			||||||
	return association.Replace()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (association *Association) Count() int {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		count        = 0
 | 
					 | 
				
			||||||
		relationship = association.Field.Relationship
 | 
					 | 
				
			||||||
		scope        = association.Scope
 | 
					 | 
				
			||||||
		fieldValue   = association.Field.Field.Interface()
 | 
					 | 
				
			||||||
		newScope     = scope.New(fieldValue)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if relationship.Kind == "many_to_many" {
 | 
					 | 
				
			||||||
		relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count)
 | 
					 | 
				
			||||||
	} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
 | 
					 | 
				
			||||||
		query := scope.DB()
 | 
					 | 
				
			||||||
		for idx, foreignKey := range relationship.ForeignDBNames {
 | 
					 | 
				
			||||||
			if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
 | 
					 | 
				
			||||||
				query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
 | 
					 | 
				
			||||||
					field.Field.Interface())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if relationship.PolymorphicType != "" {
 | 
					 | 
				
			||||||
			query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		query.Model(fieldValue).Count(&count)
 | 
					 | 
				
			||||||
	} else if relationship.Kind == "belongs_to" {
 | 
					 | 
				
			||||||
		query := scope.DB()
 | 
					 | 
				
			||||||
		for idx, primaryKey := range relationship.AssociationForeignDBNames {
 | 
					 | 
				
			||||||
			if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok {
 | 
					 | 
				
			||||||
				query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)),
 | 
					 | 
				
			||||||
					field.Field.Interface())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		query.Model(fieldValue).Count(&count)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return count
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
 | 
					 | 
				
			||||||
	scope := association.Scope
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, value := range values {
 | 
					 | 
				
			||||||
		reflectValue := reflect.Indirect(reflect.ValueOf(value))
 | 
					 | 
				
			||||||
		if reflectValue.Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
			for i := 0; i < reflectValue.Len(); i++ {
 | 
					 | 
				
			||||||
				primaryKeys := []interface{}{}
 | 
					 | 
				
			||||||
				newScope := scope.New(reflectValue.Index(i).Interface())
 | 
					 | 
				
			||||||
				for _, column := range columns {
 | 
					 | 
				
			||||||
					if field, ok := newScope.FieldByName(column); ok {
 | 
					 | 
				
			||||||
						primaryKeys = append(primaryKeys, field.Field.Interface())
 | 
					 | 
				
			||||||
					} else {
 | 
					 | 
				
			||||||
						primaryKeys = append(primaryKeys, "")
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				results = append(results, primaryKeys)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else if reflectValue.Kind() == reflect.Struct {
 | 
					 | 
				
			||||||
			newScope := scope.New(value)
 | 
					 | 
				
			||||||
			var primaryKeys []interface{}
 | 
					 | 
				
			||||||
			for _, column := range columns {
 | 
					 | 
				
			||||||
				if field, ok := newScope.FieldByName(column); ok {
 | 
					 | 
				
			||||||
					primaryKeys = append(primaryKeys, field.Field.Interface())
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					primaryKeys = append(primaryKeys, "")
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			results = append(results, primaryKeys)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func toQueryMarks(primaryValues [][]interface{}) string {
 | 
					 | 
				
			||||||
	var results []string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, primaryValue := range primaryValues {
 | 
					 | 
				
			||||||
		var marks []string
 | 
					 | 
				
			||||||
		for _, _ = range primaryValue {
 | 
					 | 
				
			||||||
			marks = append(marks, "?")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if len(marks) > 1 {
 | 
					 | 
				
			||||||
			results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			results = append(results, strings.Join(marks, ""))
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return strings.Join(results, ",")
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func toQueryCondition(scope *Scope, columns []string) string {
 | 
					 | 
				
			||||||
	var newColumns []string
 | 
					 | 
				
			||||||
	for _, column := range columns {
 | 
					 | 
				
			||||||
		newColumns = append(newColumns, scope.Quote(column))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(columns) > 1 {
 | 
					 | 
				
			||||||
		return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		return strings.Join(newColumns, ",")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
 | 
					 | 
				
			||||||
	for _, primaryValue := range primaryValues {
 | 
					 | 
				
			||||||
		for _, value := range primaryValue {
 | 
					 | 
				
			||||||
			values = append(values, value)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return values
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -5,6 +5,8 @@ import (
 | 
				
			|||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/jinzhu/gorm"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestBelongsTo(t *testing.T) {
 | 
					func TestBelongsTo(t *testing.T) {
 | 
				
			||||||
@ -16,7 +18,7 @@ func TestBelongsTo(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Save(&post).Error; err != nil {
 | 
						if err := DB.Save(&post).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Got errors when save post", err.Error())
 | 
							t.Error("Got errors when save post", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if post.Category.ID == 0 || post.MainCategory.ID == 0 {
 | 
						if post.Category.ID == 0 || post.MainCategory.ID == 0 {
 | 
				
			||||||
@ -177,6 +179,49 @@ func TestBelongsTo(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBelongsToOverrideForeignKey1(t *testing.T) {
 | 
				
			||||||
 | 
						type Profile struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Name string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type User struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Profile      Profile `gorm:"ForeignKey:ProfileRefer"`
 | 
				
			||||||
 | 
							ProfileRefer int
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
 | 
				
			||||||
 | 
							if relation.Relationship.Kind != "belongs_to" ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
 | 
				
			||||||
 | 
								t.Errorf("Override belongs to foreign key with tag")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBelongsToOverrideForeignKey2(t *testing.T) {
 | 
				
			||||||
 | 
						type Profile struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Refer string
 | 
				
			||||||
 | 
							Name  string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type User struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Profile   Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
 | 
				
			||||||
 | 
							ProfileID int
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
 | 
				
			||||||
 | 
							if relation.Relationship.Kind != "belongs_to" ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
 | 
				
			||||||
 | 
								t.Errorf("Override belongs to foreign key with tag")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestHasOne(t *testing.T) {
 | 
					func TestHasOne(t *testing.T) {
 | 
				
			||||||
	user := User{
 | 
						user := User{
 | 
				
			||||||
		Name:       "has one",
 | 
							Name:       "has one",
 | 
				
			||||||
@ -184,7 +229,7 @@ func TestHasOne(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Save(&user).Error; err != nil {
 | 
						if err := DB.Save(&user).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Got errors when save user", err.Error())
 | 
							t.Error("Got errors when save user", err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if user.CreditCard.UserId.Int64 == 0 {
 | 
						if user.CreditCard.UserId.Int64 == 0 {
 | 
				
			||||||
@ -323,6 +368,49 @@ func TestHasOne(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestHasOneOverrideForeignKey1(t *testing.T) {
 | 
				
			||||||
 | 
						type Profile struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Name      string
 | 
				
			||||||
 | 
							UserRefer uint
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type User struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Profile Profile `gorm:"ForeignKey:UserRefer"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
 | 
				
			||||||
 | 
							if relation.Relationship.Kind != "has_one" ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
 | 
				
			||||||
 | 
								t.Errorf("Override belongs to foreign key with tag")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestHasOneOverrideForeignKey2(t *testing.T) {
 | 
				
			||||||
 | 
						type Profile struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Name   string
 | 
				
			||||||
 | 
							UserID uint
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type User struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Refer   string
 | 
				
			||||||
 | 
							Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
 | 
				
			||||||
 | 
							if relation.Relationship.Kind != "has_one" ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
 | 
				
			||||||
 | 
								t.Errorf("Override belongs to foreign key with tag")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestHasMany(t *testing.T) {
 | 
					func TestHasMany(t *testing.T) {
 | 
				
			||||||
	post := Post{
 | 
						post := Post{
 | 
				
			||||||
		Title:    "post has many",
 | 
							Title:    "post has many",
 | 
				
			||||||
@ -331,7 +419,7 @@ func TestHasMany(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Save(&post).Error; err != nil {
 | 
						if err := DB.Save(&post).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Got errors when save post", err.Error())
 | 
							t.Error("Got errors when save post", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, comment := range post.Comments {
 | 
						for _, comment := range post.Comments {
 | 
				
			||||||
@ -462,6 +550,49 @@ func TestHasMany(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestHasManyOverrideForeignKey1(t *testing.T) {
 | 
				
			||||||
 | 
						type Profile struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Name      string
 | 
				
			||||||
 | 
							UserRefer uint
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type User struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Profile []Profile `gorm:"ForeignKey:UserRefer"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
 | 
				
			||||||
 | 
							if relation.Relationship.Kind != "has_many" ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
 | 
				
			||||||
 | 
								t.Errorf("Override belongs to foreign key with tag")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestHasManyOverrideForeignKey2(t *testing.T) {
 | 
				
			||||||
 | 
						type Profile struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Name   string
 | 
				
			||||||
 | 
							UserID uint
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type User struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Refer   string
 | 
				
			||||||
 | 
							Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
 | 
				
			||||||
 | 
							if relation.Relationship.Kind != "has_many" ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
 | 
				
			||||||
 | 
								!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
 | 
				
			||||||
 | 
								t.Errorf("Override belongs to foreign key with tag")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestManyToMany(t *testing.T) {
 | 
					func TestManyToMany(t *testing.T) {
 | 
				
			||||||
	DB.Raw("delete from languages")
 | 
						DB.Raw("delete from languages")
 | 
				
			||||||
	var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
 | 
						var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										210
									
								
								callback.go
									
									
									
									
									
								
							
							
						
						
									
										210
									
								
								callback.go
									
									
									
									
									
								
							@ -4,34 +4,39 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type callback struct {
 | 
					// DefaultCallback default callbacks defined by gorm
 | 
				
			||||||
 | 
					var DefaultCallback = &Callback{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Callback is a struct that contains all CURD callbacks
 | 
				
			||||||
 | 
					//   Field `creates` contains callbacks will be call when creating object
 | 
				
			||||||
 | 
					//   Field `updates` contains callbacks will be call when updating object
 | 
				
			||||||
 | 
					//   Field `deletes` contains callbacks will be call when deleting object
 | 
				
			||||||
 | 
					//   Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
 | 
				
			||||||
 | 
					//   Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
 | 
				
			||||||
 | 
					//   Field `processors` contains all callback processors, will be used to generate above callbacks in order
 | 
				
			||||||
 | 
					type Callback struct {
 | 
				
			||||||
	creates    []*func(scope *Scope)
 | 
						creates    []*func(scope *Scope)
 | 
				
			||||||
	updates    []*func(scope *Scope)
 | 
						updates    []*func(scope *Scope)
 | 
				
			||||||
	deletes    []*func(scope *Scope)
 | 
						deletes    []*func(scope *Scope)
 | 
				
			||||||
	queries    []*func(scope *Scope)
 | 
						queries    []*func(scope *Scope)
 | 
				
			||||||
	rowQueries []*func(scope *Scope)
 | 
						rowQueries []*func(scope *Scope)
 | 
				
			||||||
	processors []*callbackProcessor
 | 
						processors []*CallbackProcessor
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type callbackProcessor struct {
 | 
					// CallbackProcessor contains callback informations
 | 
				
			||||||
	name      string
 | 
					type CallbackProcessor struct {
 | 
				
			||||||
	before    string
 | 
						name      string              // current callback's name
 | 
				
			||||||
	after     string
 | 
						before    string              // register current callback before a callback
 | 
				
			||||||
	replace   bool
 | 
						after     string              // register current callback after a callback
 | 
				
			||||||
	remove    bool
 | 
						replace   bool                // replace callbacks with same name
 | 
				
			||||||
	typ       string
 | 
						remove    bool                // delete callbacks with same name
 | 
				
			||||||
	processor *func(scope *Scope)
 | 
						kind      string              // callback type: create, update, delete, query, row_query
 | 
				
			||||||
	callback  *callback
 | 
						processor *func(scope *Scope) // callback handler
 | 
				
			||||||
 | 
						parent    *Callback
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *callback) addProcessor(typ string) *callbackProcessor {
 | 
					func (c *Callback) clone() *Callback {
 | 
				
			||||||
	cp := &callbackProcessor{typ: typ, callback: c}
 | 
						return &Callback{
 | 
				
			||||||
	c.processors = append(c.processors, cp)
 | 
					 | 
				
			||||||
	return cp
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c *callback) clone() *callback {
 | 
					 | 
				
			||||||
	return &callback{
 | 
					 | 
				
			||||||
		creates:    c.creates,
 | 
							creates:    c.creates,
 | 
				
			||||||
		updates:    c.updates,
 | 
							updates:    c.updates,
 | 
				
			||||||
		deletes:    c.deletes,
 | 
							deletes:    c.deletes,
 | 
				
			||||||
@ -40,57 +45,95 @@ func (c *callback) clone() *callback {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *callback) Create() *callbackProcessor {
 | 
					// Create could be used to register callbacks for creating object
 | 
				
			||||||
	return c.addProcessor("create")
 | 
					//     db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
 | 
				
			||||||
 | 
					//       // business logic
 | 
				
			||||||
 | 
					//       ...
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//       // set error if some thing wrong happened, will rollback the creating
 | 
				
			||||||
 | 
					//       scope.Err(errors.New("error"))
 | 
				
			||||||
 | 
					//     })
 | 
				
			||||||
 | 
					func (c *Callback) Create() *CallbackProcessor {
 | 
				
			||||||
 | 
						return &CallbackProcessor{kind: "create", parent: c}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *callback) Update() *callbackProcessor {
 | 
					// Update could be used to register callbacks for updating object, refer `Create` for usage
 | 
				
			||||||
	return c.addProcessor("update")
 | 
					func (c *Callback) Update() *CallbackProcessor {
 | 
				
			||||||
 | 
						return &CallbackProcessor{kind: "update", parent: c}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *callback) Delete() *callbackProcessor {
 | 
					// Delete could be used to register callbacks for deleting object, refer `Create` for usage
 | 
				
			||||||
	return c.addProcessor("delete")
 | 
					func (c *Callback) Delete() *CallbackProcessor {
 | 
				
			||||||
 | 
						return &CallbackProcessor{kind: "delete", parent: c}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *callback) Query() *callbackProcessor {
 | 
					// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
 | 
				
			||||||
	return c.addProcessor("query")
 | 
					// Refer `Create` for usage
 | 
				
			||||||
 | 
					func (c *Callback) Query() *CallbackProcessor {
 | 
				
			||||||
 | 
						return &CallbackProcessor{kind: "query", parent: c}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *callback) RowQuery() *callbackProcessor {
 | 
					// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
 | 
				
			||||||
	return c.addProcessor("row_query")
 | 
					func (c *Callback) RowQuery() *CallbackProcessor {
 | 
				
			||||||
 | 
						return &CallbackProcessor{kind: "row_query", parent: c}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
 | 
					// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
 | 
				
			||||||
	cp.before = name
 | 
					func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
 | 
				
			||||||
 | 
						cp.after = callbackName
 | 
				
			||||||
	return cp
 | 
						return cp
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (cp *callbackProcessor) After(name string) *callbackProcessor {
 | 
					// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
 | 
				
			||||||
	cp.after = name
 | 
					func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
 | 
				
			||||||
 | 
						cp.before = callbackName
 | 
				
			||||||
	return cp
 | 
						return cp
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
 | 
					// Register a new callback, refer `Callbacks.Create`
 | 
				
			||||||
	cp.name = name
 | 
					func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
 | 
				
			||||||
	cp.processor = &fc
 | 
						cp.name = callbackName
 | 
				
			||||||
	cp.callback.sort()
 | 
						cp.processor = &callback
 | 
				
			||||||
 | 
						cp.parent.processors = append(cp.parent.processors, cp)
 | 
				
			||||||
 | 
						cp.parent.reorder()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (cp *callbackProcessor) Remove(name string) {
 | 
					// Remove a registered callback
 | 
				
			||||||
	fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
 | 
					//     db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
 | 
				
			||||||
	cp.name = name
 | 
					func (cp *CallbackProcessor) Remove(callbackName string) {
 | 
				
			||||||
 | 
						fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
 | 
				
			||||||
 | 
						cp.name = callbackName
 | 
				
			||||||
	cp.remove = true
 | 
						cp.remove = true
 | 
				
			||||||
	cp.callback.sort()
 | 
						cp.parent.processors = append(cp.parent.processors, cp)
 | 
				
			||||||
 | 
						cp.parent.reorder()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
 | 
					// Replace a registered callback with new callback
 | 
				
			||||||
	fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
 | 
					//     db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
 | 
				
			||||||
	cp.name = name
 | 
					//		   scope.SetColumn("Created", now)
 | 
				
			||||||
	cp.processor = &fc
 | 
					//		   scope.SetColumn("Updated", now)
 | 
				
			||||||
 | 
					//     })
 | 
				
			||||||
 | 
					func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
 | 
				
			||||||
 | 
						fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
 | 
				
			||||||
 | 
						cp.name = callbackName
 | 
				
			||||||
 | 
						cp.processor = &callback
 | 
				
			||||||
	cp.replace = true
 | 
						cp.replace = true
 | 
				
			||||||
	cp.callback.sort()
 | 
						cp.parent.processors = append(cp.parent.processors, cp)
 | 
				
			||||||
 | 
						cp.parent.reorder()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get registered callback
 | 
				
			||||||
 | 
					//    db.Callback().Create().Get("gorm:create")
 | 
				
			||||||
 | 
					func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
 | 
				
			||||||
 | 
						for _, p := range cp.parent.processors {
 | 
				
			||||||
 | 
							if p.name == callbackName && p.kind == cp.kind && !cp.remove {
 | 
				
			||||||
 | 
								return *p.processor
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// getRIndex get right index from string slice
 | 
				
			||||||
func getRIndex(strs []string, str string) int {
 | 
					func getRIndex(strs []string, str string) int {
 | 
				
			||||||
	for i := len(strs) - 1; i >= 0; i-- {
 | 
						for i := len(strs) - 1; i >= 0; i-- {
 | 
				
			||||||
		if strs[i] == str {
 | 
							if strs[i] == str {
 | 
				
			||||||
@ -100,83 +143,77 @@ func getRIndex(strs []string, str string) int {
 | 
				
			|||||||
	return -1
 | 
						return -1
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
 | 
					// sortProcessors sort callback processors based on its before, after, remove, replace
 | 
				
			||||||
	var sortCallbackProcessor func(c *callbackProcessor)
 | 
					func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
 | 
				
			||||||
	var names, sortedNames = []string{}, []string{}
 | 
						var (
 | 
				
			||||||
 | 
							allNames, sortedNames []string
 | 
				
			||||||
 | 
							sortCallbackProcessor func(c *CallbackProcessor)
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, cp := range cps {
 | 
						for _, cp := range cps {
 | 
				
			||||||
		if index := getRIndex(names, cp.name); index > -1 {
 | 
							// show warning message the callback name already exists
 | 
				
			||||||
			if !cp.replace && !cp.remove {
 | 
							if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
 | 
				
			||||||
			fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
 | 
								fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		}
 | 
							allNames = append(allNames, cp.name)
 | 
				
			||||||
		names = append(names, cp.name)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sortCallbackProcessor = func(c *callbackProcessor) {
 | 
						sortCallbackProcessor = func(c *CallbackProcessor) {
 | 
				
			||||||
		if getRIndex(sortedNames, c.name) > -1 {
 | 
							if getRIndex(sortedNames, c.name) == -1 { // if not sorted
 | 
				
			||||||
			return
 | 
								if c.before != "" { // if defined before callback
 | 
				
			||||||
		}
 | 
									if index := getRIndex(sortedNames, c.before); index != -1 {
 | 
				
			||||||
 | 
										// if before callback already sorted, append current callback just after it
 | 
				
			||||||
		if len(c.before) > 0 {
 | 
					 | 
				
			||||||
			if index := getRIndex(sortedNames, c.before); index > -1 {
 | 
					 | 
				
			||||||
					sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
 | 
										sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
 | 
				
			||||||
			} else if index := getRIndex(names, c.before); index > -1 {
 | 
									} else if index := getRIndex(allNames, c.before); index != -1 {
 | 
				
			||||||
 | 
										// if before callback exists but haven't sorted, append current callback to last
 | 
				
			||||||
					sortedNames = append(sortedNames, c.name)
 | 
										sortedNames = append(sortedNames, c.name)
 | 
				
			||||||
					sortCallbackProcessor(cps[index])
 | 
										sortCallbackProcessor(cps[index])
 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				sortedNames = append(sortedNames, c.name)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(c.after) > 0 {
 | 
								if c.after != "" { // if defined after callback
 | 
				
			||||||
			if index := getRIndex(sortedNames, c.after); index > -1 {
 | 
									if index := getRIndex(sortedNames, c.after); index != -1 {
 | 
				
			||||||
 | 
										// if after callback already sorted, append current callback just before it
 | 
				
			||||||
					sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
 | 
										sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
 | 
				
			||||||
			} else if index := getRIndex(names, c.after); index > -1 {
 | 
									} else if index := getRIndex(allNames, c.after); index != -1 {
 | 
				
			||||||
 | 
										// if after callback exists but haven't sorted
 | 
				
			||||||
					cp := cps[index]
 | 
										cp := cps[index]
 | 
				
			||||||
				if len(cp.before) == 0 {
 | 
										// set after callback's before callback to current callback
 | 
				
			||||||
 | 
										if cp.before == "" {
 | 
				
			||||||
						cp.before = c.name
 | 
											cp.before = c.name
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
					sortCallbackProcessor(cp)
 | 
										sortCallbackProcessor(cp)
 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				sortedNames = append(sortedNames, c.name)
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// if current callback haven't been sorted, append it to last
 | 
				
			||||||
			if getRIndex(sortedNames, c.name) == -1 {
 | 
								if getRIndex(sortedNames, c.name) == -1 {
 | 
				
			||||||
				sortedNames = append(sortedNames, c.name)
 | 
									sortedNames = append(sortedNames, c.name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, cp := range cps {
 | 
						for _, cp := range cps {
 | 
				
			||||||
		sortCallbackProcessor(cp)
 | 
							sortCallbackProcessor(cp)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var funcs = []*func(scope *Scope){}
 | 
						var sortedFuncs []*func(scope *Scope)
 | 
				
			||||||
	var sortedFuncs = []*func(scope *Scope){}
 | 
					 | 
				
			||||||
	for _, name := range sortedNames {
 | 
						for _, name := range sortedNames {
 | 
				
			||||||
		index := getRIndex(names, name)
 | 
							if index := getRIndex(allNames, name); !cps[index].remove {
 | 
				
			||||||
		if !cps[index].remove {
 | 
					 | 
				
			||||||
			sortedFuncs = append(sortedFuncs, cps[index].processor)
 | 
								sortedFuncs = append(sortedFuncs, cps[index].processor)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, cp := range cps {
 | 
						return sortedFuncs
 | 
				
			||||||
		if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
 | 
					 | 
				
			||||||
			if !cp.remove {
 | 
					 | 
				
			||||||
				funcs = append(funcs, cp.processor)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return append(sortedFuncs, funcs...)
 | 
					// reorder all registered processors, and reset CURD callbacks
 | 
				
			||||||
}
 | 
					func (c *Callback) reorder() {
 | 
				
			||||||
 | 
						var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
 | 
				
			||||||
func (c *callback) sort() {
 | 
					 | 
				
			||||||
	var creates, updates, deletes, queries, rowQueries []*callbackProcessor
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, processor := range c.processors {
 | 
						for _, processor := range c.processors {
 | 
				
			||||||
		switch processor.typ {
 | 
							if processor.name != "" {
 | 
				
			||||||
 | 
								switch processor.kind {
 | 
				
			||||||
			case "create":
 | 
								case "create":
 | 
				
			||||||
				creates = append(creates, processor)
 | 
									creates = append(creates, processor)
 | 
				
			||||||
			case "update":
 | 
								case "update":
 | 
				
			||||||
@ -189,6 +226,7 @@ func (c *callback) sort() {
 | 
				
			|||||||
				rowQueries = append(rowQueries, processor)
 | 
									rowQueries = append(rowQueries, processor)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.creates = sortProcessors(creates)
 | 
						c.creates = sortProcessors(creates)
 | 
				
			||||||
	c.updates = sortProcessors(updates)
 | 
						c.updates = sortProcessors(updates)
 | 
				
			||||||
@ -196,5 +234,3 @@ func (c *callback) sort() {
 | 
				
			|||||||
	c.queries = sortProcessors(queries)
 | 
						c.queries = sortProcessors(queries)
 | 
				
			||||||
	c.rowQueries = sortProcessors(rowQueries)
 | 
						c.rowQueries = sortProcessors(rowQueries)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -5,12 +5,31 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BeforeCreate(scope *Scope) {
 | 
					// Define callbacks for creating
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("BeforeSave")
 | 
					func init() {
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("BeforeCreate")
 | 
						DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:create", createCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func UpdateTimeStampWhenCreate(scope *Scope) {
 | 
					// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
 | 
				
			||||||
 | 
					func beforeCreateCallback(scope *Scope) {
 | 
				
			||||||
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							scope.CallMethod("BeforeSave")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							scope.CallMethod("BeforeCreate")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
 | 
				
			||||||
 | 
					func updateTimeStampForCreateCallback(scope *Scope) {
 | 
				
			||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		now := NowFunc()
 | 
							now := NowFunc()
 | 
				
			||||||
		scope.SetColumn("CreatedAt", now)
 | 
							scope.SetColumn("CreatedAt", now)
 | 
				
			||||||
@ -18,109 +37,108 @@ func UpdateTimeStampWhenCreate(scope *Scope) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Create(scope *Scope) {
 | 
					// createCallback the callback used to insert data into database
 | 
				
			||||||
	defer scope.Trace(NowFunc())
 | 
					func createCallback(scope *Scope) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		// set create sql
 | 
							defer scope.trace(NowFunc())
 | 
				
			||||||
		var sqls, columns []string
 | 
					
 | 
				
			||||||
		fields := scope.Fields()
 | 
							var (
 | 
				
			||||||
		for _, field := range fields {
 | 
								columns, placeholders        []string
 | 
				
			||||||
 | 
								blankColumnsWithDefaultValue []string
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, field := range scope.Fields() {
 | 
				
			||||||
			if scope.changeableField(field) {
 | 
								if scope.changeableField(field) {
 | 
				
			||||||
				if field.IsNormal {
 | 
									if field.IsNormal {
 | 
				
			||||||
					if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
 | 
										if !field.IsPrimaryKey || !field.IsBlank {
 | 
				
			||||||
						if !field.IsBlank || !field.HasDefaultValue {
 | 
											if field.IsBlank && field.HasDefaultValue {
 | 
				
			||||||
 | 
												blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
 | 
				
			||||||
 | 
												scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
 | 
				
			||||||
 | 
											} else {
 | 
				
			||||||
							columns = append(columns, scope.Quote(field.DBName))
 | 
												columns = append(columns, scope.Quote(field.DBName))
 | 
				
			||||||
							sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
 | 
												placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
 | 
				
			||||||
						} else if field.HasDefaultValue {
 | 
					 | 
				
			||||||
							var hasDefaultValueColumns []string
 | 
					 | 
				
			||||||
							if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
 | 
					 | 
				
			||||||
								hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
							hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
 | 
					 | 
				
			||||||
							scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
									} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
 | 
				
			||||||
					for _, dbName := range relationship.ForeignDBNames {
 | 
										for _, foreignKey := range field.Relationship.ForeignDBNames {
 | 
				
			||||||
						if relationField := fields[dbName]; !scope.changeableField(relationField) {
 | 
											if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
 | 
				
			||||||
							columns = append(columns, scope.Quote(relationField.DBName))
 | 
												columns = append(columns, scope.Quote(foreignField.DBName))
 | 
				
			||||||
							sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
 | 
												placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		returningKey := "*"
 | 
							var (
 | 
				
			||||||
		primaryField := scope.PrimaryField()
 | 
								returningColumn = "*"
 | 
				
			||||||
		if primaryField != nil {
 | 
								quotedTableName = scope.QuotedTableName()
 | 
				
			||||||
			returningKey = scope.Quote(primaryField.DBName)
 | 
								primaryField    = scope.PrimaryField()
 | 
				
			||||||
 | 
								extraOption     string
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if str, ok := scope.Get("gorm:insert_option"); ok {
 | 
				
			||||||
 | 
								extraOption = fmt.Sprint(str)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if primaryField != nil {
 | 
				
			||||||
 | 
								returningColumn = scope.Quote(primaryField.DBName)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(columns) == 0 {
 | 
							if len(columns) == 0 {
 | 
				
			||||||
			scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
 | 
								scope.Raw(fmt.Sprintf(
 | 
				
			||||||
				scope.QuotedTableName(),
 | 
									"INSERT INTO %v DEFAULT VALUES%v%v",
 | 
				
			||||||
				scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
 | 
									quotedTableName,
 | 
				
			||||||
 | 
									addExtraSpaceIfExist(extraOption),
 | 
				
			||||||
 | 
									addExtraSpaceIfExist(lastInsertIDReturningSuffix),
 | 
				
			||||||
			))
 | 
								))
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			scope.Raw(fmt.Sprintf(
 | 
								scope.Raw(fmt.Sprintf(
 | 
				
			||||||
				"INSERT INTO %v (%v) VALUES (%v) %v",
 | 
									"INSERT INTO %v (%v) VALUES (%v)%v%v",
 | 
				
			||||||
				scope.QuotedTableName(),
 | 
									scope.QuotedTableName(),
 | 
				
			||||||
				strings.Join(columns, ","),
 | 
									strings.Join(columns, ","),
 | 
				
			||||||
				strings.Join(sqls, ","),
 | 
									strings.Join(placeholders, ","),
 | 
				
			||||||
				scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
 | 
									addExtraSpaceIfExist(extraOption),
 | 
				
			||||||
 | 
									addExtraSpaceIfExist(lastInsertIDReturningSuffix),
 | 
				
			||||||
			))
 | 
								))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// execute create sql
 | 
							// execute create sql
 | 
				
			||||||
		if scope.Dialect().SupportLastInsertId() {
 | 
							if lastInsertIDReturningSuffix == "" || primaryField == nil {
 | 
				
			||||||
			if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
 | 
								if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
				
			||||||
				id, err := result.LastInsertId()
 | 
									// set rows affected count
 | 
				
			||||||
				if scope.Err(err) == nil {
 | 
					 | 
				
			||||||
				scope.db.RowsAffected, _ = result.RowsAffected()
 | 
									scope.db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// set primary value to primary field
 | 
				
			||||||
				if primaryField != nil && primaryField.IsBlank {
 | 
									if primaryField != nil && primaryField.IsBlank {
 | 
				
			||||||
						scope.Err(scope.SetColumn(primaryField, id))
 | 
										if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
 | 
				
			||||||
 | 
											scope.Err(primaryField.Set(primaryValue))
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if primaryField == nil {
 | 
								if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
 | 
				
			||||||
				if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
 | 
					 | 
				
			||||||
					scope.db.RowsAffected, _ = results.RowsAffected()
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					scope.Err(err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
 | 
					 | 
				
			||||||
				scope.db.RowsAffected = 1
 | 
									scope.db.RowsAffected = 1
 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					scope.Err(err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ForceReloadAfterCreate(scope *Scope) {
 | 
					// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
 | 
				
			||||||
	if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
 | 
					func forceReloadAfterCreateCallback(scope *Scope) {
 | 
				
			||||||
		scope.DB().New().Select(columns.([]string)).First(scope.Value)
 | 
						if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
 | 
				
			||||||
 | 
							scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func AfterCreate(scope *Scope) {
 | 
					// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("AfterCreate")
 | 
					func afterCreateCallback(scope *Scope) {
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("AfterSave")
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							scope.CallMethod("AfterCreate")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							scope.CallMethod("AfterSave")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func init() {
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:create", Create)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
 | 
					 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -2,35 +2,52 @@ package gorm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import "fmt"
 | 
					import "fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BeforeDelete(scope *Scope) {
 | 
					// Define callbacks for deleting
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("BeforeDelete")
 | 
					func init() {
 | 
				
			||||||
 | 
						DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Delete(scope *Scope) {
 | 
					// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
 | 
				
			||||||
 | 
					func beforeDeleteCallback(scope *Scope) {
 | 
				
			||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							scope.CallMethod("BeforeDelete")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
 | 
				
			||||||
 | 
					func deleteCallback(scope *Scope) {
 | 
				
			||||||
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							var extraOption string
 | 
				
			||||||
 | 
							if str, ok := scope.Get("gorm:delete_option"); ok {
 | 
				
			||||||
 | 
								extraOption = fmt.Sprint(str)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
 | 
							if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
 | 
				
			||||||
			scope.Raw(
 | 
								scope.Raw(fmt.Sprintf(
 | 
				
			||||||
				fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
 | 
									"UPDATE %v SET deleted_at=%v%v%v",
 | 
				
			||||||
				scope.QuotedTableName(),
 | 
									scope.QuotedTableName(),
 | 
				
			||||||
				scope.AddToVars(NowFunc()),
 | 
									scope.AddToVars(NowFunc()),
 | 
				
			||||||
					scope.CombinedConditionSql(),
 | 
									addExtraSpaceIfExist(scope.CombinedConditionSql()),
 | 
				
			||||||
				))
 | 
									addExtraSpaceIfExist(extraOption),
 | 
				
			||||||
 | 
								)).Exec()
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql()))
 | 
								scope.Raw(fmt.Sprintf(
 | 
				
			||||||
 | 
									"DELETE FROM %v%v%v",
 | 
				
			||||||
 | 
									scope.QuotedTableName(),
 | 
				
			||||||
 | 
									addExtraSpaceIfExist(scope.CombinedConditionSql()),
 | 
				
			||||||
 | 
									addExtraSpaceIfExist(extraOption),
 | 
				
			||||||
 | 
								)).Exec()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
		scope.Exec()
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func AfterDelete(scope *Scope) {
 | 
					// afterDeleteCallback will invoke `AfterDelete` method after deleting
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("AfterDelete")
 | 
					func afterDeleteCallback(scope *Scope) {
 | 
				
			||||||
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							scope.CallMethod("AfterDelete")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func init() {
 | 
					 | 
				
			||||||
	DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
 | 
					 | 
				
			||||||
	DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
 | 
					 | 
				
			||||||
	DefaultCallback.Delete().Register("gorm:delete", Delete)
 | 
					 | 
				
			||||||
	DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
 | 
					 | 
				
			||||||
	DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -6,115 +6,89 @@ import (
 | 
				
			|||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Query(scope *Scope) {
 | 
					// Define callbacks for querying
 | 
				
			||||||
	defer scope.Trace(NowFunc())
 | 
					func init() {
 | 
				
			||||||
 | 
						DefaultCallback.Query().Register("gorm:query", queryCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Query().Register("gorm:preload", preloadCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// queryCallback used to query data from database
 | 
				
			||||||
 | 
					func queryCallback(scope *Scope) {
 | 
				
			||||||
 | 
						defer scope.trace(NowFunc())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		isSlice    bool
 | 
							isSlice    bool
 | 
				
			||||||
		isPtr      bool
 | 
							isPtr      bool
 | 
				
			||||||
		anyRecordFound bool
 | 
							results    = scope.IndirectValue()
 | 
				
			||||||
		destType       reflect.Type
 | 
							resultType reflect.Type
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
 | 
						if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
 | 
				
			||||||
		if primaryKey := scope.PrimaryKey(); primaryKey != "" {
 | 
							if primaryField := scope.PrimaryField(); primaryField != nil {
 | 
				
			||||||
			scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
 | 
								scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var dest = scope.IndirectValue()
 | 
					 | 
				
			||||||
	if value, ok := scope.Get("gorm:query_destination"); ok {
 | 
						if value, ok := scope.Get("gorm:query_destination"); ok {
 | 
				
			||||||
		dest = reflect.Indirect(reflect.ValueOf(value))
 | 
							results = reflect.Indirect(reflect.ValueOf(value))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if kind := dest.Kind(); kind == reflect.Slice {
 | 
						if kind := results.Kind(); kind == reflect.Slice {
 | 
				
			||||||
		isSlice = true
 | 
							isSlice = true
 | 
				
			||||||
		destType = dest.Type().Elem()
 | 
							resultType = results.Type().Elem()
 | 
				
			||||||
		dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
 | 
							results.Set(reflect.MakeSlice(results.Type(), 0, 0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if destType.Kind() == reflect.Ptr {
 | 
							if resultType.Kind() == reflect.Ptr {
 | 
				
			||||||
			isPtr = true
 | 
								isPtr = true
 | 
				
			||||||
			destType = destType.Elem()
 | 
								resultType = resultType.Elem()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else if kind != reflect.Struct {
 | 
						} else if kind != reflect.Struct {
 | 
				
			||||||
		scope.Err(errors.New("unsupported destination, should be slice or struct"))
 | 
							scope.Err(errors.New("unsupported destination, should be slice or struct"))
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	scope.prepareQuerySql()
 | 
						scope.prepareQuerySQL()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
 | 
					 | 
				
			||||||
		scope.db.RowsAffected = 0
 | 
							scope.db.RowsAffected = 0
 | 
				
			||||||
 | 
							if str, ok := scope.Get("gorm:query_option"); ok {
 | 
				
			||||||
		if scope.Err(err) != nil {
 | 
								scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
				
			||||||
			defer rows.Close()
 | 
								defer rows.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			columns, _ := rows.Columns()
 | 
								columns, _ := rows.Columns()
 | 
				
			||||||
			for rows.Next() {
 | 
								for rows.Next() {
 | 
				
			||||||
				scope.db.RowsAffected++
 | 
									scope.db.RowsAffected++
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			anyRecordFound = true
 | 
									elem := results
 | 
				
			||||||
			elem := dest
 | 
					 | 
				
			||||||
				if isSlice {
 | 
									if isSlice {
 | 
				
			||||||
				elem = reflect.New(destType).Elem()
 | 
										elem = reflect.New(resultType).Elem()
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var values = make([]interface{}, len(columns))
 | 
									scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap())
 | 
				
			||||||
 | 
					 | 
				
			||||||
			fields := scope.New(elem.Addr().Interface()).Fields()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for index, column := range columns {
 | 
					 | 
				
			||||||
				if field, ok := fields[column]; ok {
 | 
					 | 
				
			||||||
					if field.Field.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
						values[index] = field.Field.Addr().Interface()
 | 
					 | 
				
			||||||
					} else {
 | 
					 | 
				
			||||||
						reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
 | 
					 | 
				
			||||||
						reflectValue.Elem().Set(field.Field.Addr())
 | 
					 | 
				
			||||||
						values[index] = reflectValue.Interface()
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					var value interface{}
 | 
					 | 
				
			||||||
					values[index] = &value
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			scope.Err(rows.Scan(values...))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for index, column := range columns {
 | 
					 | 
				
			||||||
				value := values[index]
 | 
					 | 
				
			||||||
				if field, ok := fields[column]; ok {
 | 
					 | 
				
			||||||
					if field.Field.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
						field.Field.Set(reflect.ValueOf(value).Elem())
 | 
					 | 
				
			||||||
					} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
 | 
					 | 
				
			||||||
						field.Field.Set(v)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if isSlice {
 | 
									if isSlice {
 | 
				
			||||||
					if isPtr {
 | 
										if isPtr {
 | 
				
			||||||
					dest.Set(reflect.Append(dest, elem.Addr()))
 | 
											results.Set(reflect.Append(results, elem.Addr()))
 | 
				
			||||||
					} else {
 | 
										} else {
 | 
				
			||||||
					dest.Set(reflect.Append(dest, elem))
 | 
											results.Set(reflect.Append(results, elem))
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !anyRecordFound && !isSlice {
 | 
								if scope.db.RowsAffected == 0 && !isSlice {
 | 
				
			||||||
			scope.Err(RecordNotFound)
 | 
									scope.Err(ErrRecordNotFound)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func AfterQuery(scope *Scope) {
 | 
					// afterQueryCallback will invoke `AfterFind` method after querying
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("AfterFind")
 | 
					func afterQueryCallback(scope *Scope) {
 | 
				
			||||||
 | 
						if !scope.HasError() {
 | 
				
			||||||
 | 
							scope.CallMethod("AfterFind")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func init() {
 | 
					 | 
				
			||||||
	DefaultCallback.Query().Register("gorm:query", Query)
 | 
					 | 
				
			||||||
	DefaultCallback.Query().Register("gorm:preload", Preload)
 | 
					 | 
				
			||||||
	DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										308
									
								
								callback_query_preload.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										308
									
								
								callback_query_preload.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,308 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// preloadCallback used to preload associations
 | 
				
			||||||
 | 
					func preloadCallback(scope *Scope) {
 | 
				
			||||||
 | 
						if scope.Search.preload == nil || scope.HasError() {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							preloadedMap = map[string]bool{}
 | 
				
			||||||
 | 
							fields       = scope.Fields()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, preload := range scope.Search.preload {
 | 
				
			||||||
 | 
							var (
 | 
				
			||||||
 | 
								preloadFields = strings.Split(preload.schema, ".")
 | 
				
			||||||
 | 
								currentScope  = scope
 | 
				
			||||||
 | 
								currentFields = fields
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for idx, preloadField := range preloadFields {
 | 
				
			||||||
 | 
								var currentPreloadConditions []interface{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// if not preloaded
 | 
				
			||||||
 | 
								if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// assign search conditions to last preload
 | 
				
			||||||
 | 
									if idx == len(preloadFields)-1 {
 | 
				
			||||||
 | 
										currentPreloadConditions = preload.conditions
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									for _, field := range currentFields {
 | 
				
			||||||
 | 
										if field.Name != preloadField || field.Relationship == nil {
 | 
				
			||||||
 | 
											continue
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										switch field.Relationship.Kind {
 | 
				
			||||||
 | 
										case "has_one":
 | 
				
			||||||
 | 
											currentScope.handleHasOnePreload(field, currentPreloadConditions)
 | 
				
			||||||
 | 
										case "has_many":
 | 
				
			||||||
 | 
											currentScope.handleHasManyPreload(field, currentPreloadConditions)
 | 
				
			||||||
 | 
										case "belongs_to":
 | 
				
			||||||
 | 
											currentScope.handleBelongsToPreload(field, currentPreloadConditions)
 | 
				
			||||||
 | 
										case "many_to_many":
 | 
				
			||||||
 | 
											currentScope.handleManyToManyPreload(field, currentPreloadConditions)
 | 
				
			||||||
 | 
										default:
 | 
				
			||||||
 | 
											scope.Err(errors.New("unsupported relation"))
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										preloadedMap[preloadKey] = true
 | 
				
			||||||
 | 
										break
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if !preloadedMap[preloadKey] {
 | 
				
			||||||
 | 
										scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// preload next level
 | 
				
			||||||
 | 
								if idx < len(preloadFields)-1 {
 | 
				
			||||||
 | 
									currentScope = currentScope.getColumnAsScope(preloadField)
 | 
				
			||||||
 | 
									currentFields = currentScope.Fields()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							preloadDB         = scope.NewDB()
 | 
				
			||||||
 | 
							preloadConditions []interface{}
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, condition := range conditions {
 | 
				
			||||||
 | 
							if scopes, ok := condition.(func(*DB) *DB); ok {
 | 
				
			||||||
 | 
								preloadDB = scopes(preloadDB)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								preloadConditions = append(preloadConditions, condition)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return preloadDB, preloadConditions
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// handleHasOnePreload used to preload has one associations
 | 
				
			||||||
 | 
					func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
 | 
				
			||||||
 | 
						relation := field.Relationship
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// get relations's primary keys
 | 
				
			||||||
 | 
						primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
 | 
				
			||||||
 | 
						if len(primaryKeys) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// preload conditions
 | 
				
			||||||
 | 
						preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// find relations
 | 
				
			||||||
 | 
						results := makeSlice(field.Struct.Type)
 | 
				
			||||||
 | 
						scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// assign find results
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							resultsValue       = indirect(reflect.ValueOf(results))
 | 
				
			||||||
 | 
							indirectScopeValue = scope.IndirectValue()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < resultsValue.Len(); i++ {
 | 
				
			||||||
 | 
							result := resultsValue.Index(i)
 | 
				
			||||||
 | 
							if indirectScopeValue.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
								foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
 | 
				
			||||||
 | 
								for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
				
			||||||
 | 
									if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
 | 
				
			||||||
 | 
										indirectValue.FieldByName(field.Name).Set(result)
 | 
				
			||||||
 | 
										break
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								scope.Err(field.Set(result))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// handleHasManyPreload used to preload has many associations
 | 
				
			||||||
 | 
					func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
 | 
				
			||||||
 | 
						relation := field.Relationship
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// get relations's primary keys
 | 
				
			||||||
 | 
						primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
 | 
				
			||||||
 | 
						if len(primaryKeys) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// preload conditions
 | 
				
			||||||
 | 
						preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// find relations
 | 
				
			||||||
 | 
						results := makeSlice(field.Struct.Type)
 | 
				
			||||||
 | 
						scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// assign find results
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							resultsValue       = indirect(reflect.ValueOf(results))
 | 
				
			||||||
 | 
							indirectScopeValue = scope.IndirectValue()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if indirectScopeValue.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
							for i := 0; i < resultsValue.Len(); i++ {
 | 
				
			||||||
 | 
								result := resultsValue.Index(i)
 | 
				
			||||||
 | 
								foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
 | 
				
			||||||
 | 
								for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
				
			||||||
 | 
									object := indirect(indirectScopeValue.Index(j))
 | 
				
			||||||
 | 
									if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
 | 
				
			||||||
 | 
										objectField := object.FieldByName(field.Name)
 | 
				
			||||||
 | 
										objectField.Set(reflect.Append(objectField, result))
 | 
				
			||||||
 | 
										break
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							scope.Err(field.Set(resultsValue))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// handleBelongsToPreload used to preload belongs to associations
 | 
				
			||||||
 | 
					func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
 | 
				
			||||||
 | 
						relation := field.Relationship
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// preload conditions
 | 
				
			||||||
 | 
						preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// get relations's primary keys
 | 
				
			||||||
 | 
						primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
 | 
				
			||||||
 | 
						if len(primaryKeys) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// find relations
 | 
				
			||||||
 | 
						results := makeSlice(field.Struct.Type)
 | 
				
			||||||
 | 
						scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// assign find results
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							resultsValue       = indirect(reflect.ValueOf(results))
 | 
				
			||||||
 | 
							indirectScopeValue = scope.IndirectValue()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < resultsValue.Len(); i++ {
 | 
				
			||||||
 | 
							result := resultsValue.Index(i)
 | 
				
			||||||
 | 
							if indirectScopeValue.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
								value := getValueFromFields(result, relation.AssociationForeignFieldNames)
 | 
				
			||||||
 | 
								for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
				
			||||||
 | 
									object := indirect(indirectScopeValue.Index(j))
 | 
				
			||||||
 | 
									if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
 | 
				
			||||||
 | 
										object.FieldByName(field.Name).Set(result)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								scope.Err(field.Set(result))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// handleManyToManyPreload used to preload many to many associations
 | 
				
			||||||
 | 
					func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							relation         = field.Relationship
 | 
				
			||||||
 | 
							joinTableHandler = relation.JoinTableHandler
 | 
				
			||||||
 | 
							fieldType        = field.Struct.Type.Elem()
 | 
				
			||||||
 | 
							foreignKeyValue  interface{}
 | 
				
			||||||
 | 
							foreignKeyType   = reflect.ValueOf(&foreignKeyValue).Type()
 | 
				
			||||||
 | 
							linkHash         = map[string][]reflect.Value{}
 | 
				
			||||||
 | 
							isPtr            bool
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if fieldType.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
							isPtr = true
 | 
				
			||||||
 | 
							fieldType = fieldType.Elem()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var sourceKeys = []string{}
 | 
				
			||||||
 | 
						for _, key := range joinTableHandler.SourceForeignKeys() {
 | 
				
			||||||
 | 
							sourceKeys = append(sourceKeys, key.DBName)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// preload conditions
 | 
				
			||||||
 | 
						preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// generate query with join table
 | 
				
			||||||
 | 
						newScope := scope.New(reflect.New(fieldType).Interface())
 | 
				
			||||||
 | 
						preloadDB = preloadDB.Table(newScope.TableName()).Select("*")
 | 
				
			||||||
 | 
						preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// preload inline conditions
 | 
				
			||||||
 | 
						if len(preloadConditions) > 0 {
 | 
				
			||||||
 | 
							preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rows, err := preloadDB.Rows()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if scope.Err(err) != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer rows.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						columns, _ := rows.Columns()
 | 
				
			||||||
 | 
						for rows.Next() {
 | 
				
			||||||
 | 
							var (
 | 
				
			||||||
 | 
								elem   = reflect.New(fieldType).Elem()
 | 
				
			||||||
 | 
								fields = scope.New(elem.Addr().Interface()).fieldsMap()
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// register foreign keys in join tables
 | 
				
			||||||
 | 
							for _, sourceKey := range sourceKeys {
 | 
				
			||||||
 | 
								fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							scope.scan(rows, columns, fields)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// generate hashed forkey keys in join table
 | 
				
			||||||
 | 
							var foreignKeys = make([]interface{}, len(sourceKeys))
 | 
				
			||||||
 | 
							for idx, sourceKey := range sourceKeys {
 | 
				
			||||||
 | 
								foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							hashedSourceKeys := toString(foreignKeys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if isPtr {
 | 
				
			||||||
 | 
								linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// assign find results
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							indirectScopeValue = scope.IndirectValue()
 | 
				
			||||||
 | 
							fieldsSourceMap    = map[string]reflect.Value{}
 | 
				
			||||||
 | 
							foreignFieldNames  = []string{}
 | 
				
			||||||
 | 
							fields             = scope.fieldsMap()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, dbName := range relation.ForeignFieldNames {
 | 
				
			||||||
 | 
							if field, ok := fields[dbName]; ok {
 | 
				
			||||||
 | 
								foreignFieldNames = append(foreignFieldNames, field.Name)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if indirectScopeValue.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
							for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
				
			||||||
 | 
								object := indirect(indirectScopeValue.Index(j))
 | 
				
			||||||
 | 
								fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else if indirectScopeValue.IsValid() {
 | 
				
			||||||
 | 
							fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for source, link := range linkHash {
 | 
				
			||||||
 | 
							fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -2,15 +2,15 @@ package gorm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import "reflect"
 | 
					import "reflect"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BeginTransaction(scope *Scope) {
 | 
					func beginTransactionCallback(scope *Scope) {
 | 
				
			||||||
	scope.Begin()
 | 
						scope.Begin()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func CommitOrRollbackTransaction(scope *Scope) {
 | 
					func commitOrRollbackTransactionCallback(scope *Scope) {
 | 
				
			||||||
	scope.CommitOrRollback()
 | 
						scope.CommitOrRollback()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SaveBeforeAssociations(scope *Scope) {
 | 
					func saveBeforeAssociationsCallback(scope *Scope) {
 | 
				
			||||||
	if !scope.shouldSaveAssociations() {
 | 
						if !scope.shouldSaveAssociations() {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SaveAfterAssociations(scope *Scope) {
 | 
					func saveAfterAssociationsCallback(scope *Scope) {
 | 
				
			||||||
	if !scope.shouldSaveAssociations() {
 | 
						if !scope.shouldSaveAssociations() {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -23,7 +23,7 @@ func afterCreate1(s *Scope)  {}
 | 
				
			|||||||
func afterCreate2(s *Scope)  {}
 | 
					func afterCreate2(s *Scope)  {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestRegisterCallback(t *testing.T) {
 | 
					func TestRegisterCallback(t *testing.T) {
 | 
				
			||||||
	var callback = &callback{processors: []*callbackProcessor{}}
 | 
						var callback = &Callback{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	callback.Create().Register("before_create1", beforeCreate1)
 | 
						callback.Create().Register("before_create1", beforeCreate1)
 | 
				
			||||||
	callback.Create().Register("before_create2", beforeCreate2)
 | 
						callback.Create().Register("before_create2", beforeCreate2)
 | 
				
			||||||
@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestRegisterCallbackWithOrder(t *testing.T) {
 | 
					func TestRegisterCallbackWithOrder(t *testing.T) {
 | 
				
			||||||
	var callback1 = &callback{processors: []*callbackProcessor{}}
 | 
						var callback1 = &Callback{}
 | 
				
			||||||
	callback1.Create().Register("before_create1", beforeCreate1)
 | 
						callback1.Create().Register("before_create1", beforeCreate1)
 | 
				
			||||||
	callback1.Create().Register("create", create)
 | 
						callback1.Create().Register("create", create)
 | 
				
			||||||
	callback1.Create().Register("after_create1", afterCreate1)
 | 
						callback1.Create().Register("after_create1", afterCreate1)
 | 
				
			||||||
@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("register callback with order")
 | 
							t.Errorf("register callback with order")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var callback2 = &callback{processors: []*callbackProcessor{}}
 | 
						var callback2 = &Callback{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	callback2.Update().Register("create", create)
 | 
						callback2.Update().Register("create", create)
 | 
				
			||||||
	callback2.Update().Before("create").Register("before_create1", beforeCreate1)
 | 
						callback2.Update().Before("create").Register("before_create1", beforeCreate1)
 | 
				
			||||||
@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
 | 
					func TestRegisterCallbackWithComplexOrder(t *testing.T) {
 | 
				
			||||||
	var callback1 = &callback{processors: []*callbackProcessor{}}
 | 
						var callback1 = &Callback{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
 | 
						callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
 | 
				
			||||||
	callback1.Query().Register("before_create1", beforeCreate1)
 | 
						callback1.Query().Register("before_create1", beforeCreate1)
 | 
				
			||||||
@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("register callback with order")
 | 
							t.Errorf("register callback with order")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var callback2 = &callback{processors: []*callbackProcessor{}}
 | 
						var callback2 = &Callback{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
 | 
						callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
 | 
				
			||||||
	callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
 | 
						callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
 | 
				
			||||||
@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
 | 
				
			|||||||
func replaceCreate(s *Scope) {}
 | 
					func replaceCreate(s *Scope) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestReplaceCallback(t *testing.T) {
 | 
					func TestReplaceCallback(t *testing.T) {
 | 
				
			||||||
	var callback = &callback{processors: []*callbackProcessor{}}
 | 
						var callback = &Callback{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	callback.Create().Before("after_create1").After("before_create1").Register("create", create)
 | 
						callback.Create().Before("after_create1").After("before_create1").Register("create", create)
 | 
				
			||||||
	callback.Create().Register("before_create1", beforeCreate1)
 | 
						callback.Create().Register("before_create1", beforeCreate1)
 | 
				
			||||||
@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestRemoveCallback(t *testing.T) {
 | 
					func TestRemoveCallback(t *testing.T) {
 | 
				
			||||||
	var callback = &callback{processors: []*callbackProcessor{}}
 | 
						var callback = &Callback{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	callback.Create().Before("after_create1").After("before_create1").Register("create", create)
 | 
						callback.Create().Before("after_create1").After("before_create1").Register("create", create)
 | 
				
			||||||
	callback.Create().Register("before_create1", beforeCreate1)
 | 
						callback.Create().Register("before_create1", beforeCreate1)
 | 
				
			||||||
 | 
				
			|||||||
@ -5,91 +5,102 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func AssignUpdateAttributes(scope *Scope) {
 | 
					// Define callbacks for updating
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:update", updateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// assignUpdatingAttributesCallback assign updating attributes to model
 | 
				
			||||||
 | 
					func assignUpdatingAttributesCallback(scope *Scope) {
 | 
				
			||||||
	if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
 | 
						if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
 | 
				
			||||||
		if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
 | 
							if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
 | 
				
			||||||
			protected, ok := scope.Get("gorm:ignore_protected_attrs")
 | 
								if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
 | 
				
			||||||
			_, updateColumn := scope.Get("gorm:update_column")
 | 
									scope.InstanceSet("gorm:update_attrs", updateMaps)
 | 
				
			||||||
			updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
 | 
								} else {
 | 
				
			||||||
 | 
					 | 
				
			||||||
			if updateColumn {
 | 
					 | 
				
			||||||
				scope.InstanceSet("gorm:update_attrs", maps)
 | 
					 | 
				
			||||||
			} else if len(updateAttrs) > 0 {
 | 
					 | 
				
			||||||
				scope.InstanceSet("gorm:update_attrs", updateAttrs)
 | 
					 | 
				
			||||||
			} else if !hasUpdate {
 | 
					 | 
				
			||||||
				scope.SkipLeft()
 | 
									scope.SkipLeft()
 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BeforeUpdate(scope *Scope) {
 | 
					// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
 | 
				
			||||||
 | 
					func beforeUpdateCallback(scope *Scope) {
 | 
				
			||||||
	if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
						if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
				
			||||||
		scope.CallMethodWithErrorCheck("BeforeSave")
 | 
							if !scope.HasError() {
 | 
				
			||||||
		scope.CallMethodWithErrorCheck("BeforeUpdate")
 | 
								scope.CallMethod("BeforeSave")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !scope.HasError() {
 | 
				
			||||||
 | 
								scope.CallMethod("BeforeUpdate")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func UpdateTimeStampWhenUpdate(scope *Scope) {
 | 
					// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
 | 
				
			||||||
 | 
					func updateTimeStampForUpdateCallback(scope *Scope) {
 | 
				
			||||||
	if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
						if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
				
			||||||
		scope.SetColumn("UpdatedAt", NowFunc())
 | 
							scope.SetColumn("UpdatedAt", NowFunc())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Update(scope *Scope) {
 | 
					// updateCallback the callback used to update data to database
 | 
				
			||||||
 | 
					func updateCallback(scope *Scope) {
 | 
				
			||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		var sqls []string
 | 
							var sqls []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
 | 
							if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
 | 
				
			||||||
			for key, value := range updateAttrs.(map[string]interface{}) {
 | 
								for column, value := range updateAttrs.(map[string]interface{}) {
 | 
				
			||||||
				if scope.changeableDBColumn(key) {
 | 
									sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
 | 
				
			||||||
					sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			fields := scope.Fields()
 | 
								for _, field := range scope.Fields() {
 | 
				
			||||||
			for _, field := range fields {
 | 
									if scope.changeableField(field) {
 | 
				
			||||||
				if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
 | 
										if !field.IsPrimaryKey && field.IsNormal {
 | 
				
			||||||
						sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
 | 
											sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
 | 
				
			||||||
					} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
										} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
				
			||||||
					for _, dbName := range relationship.ForeignDBNames {
 | 
											for _, foreignKey := range relationship.ForeignDBNames {
 | 
				
			||||||
						if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
 | 
												if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
 | 
				
			||||||
							sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
 | 
													sqls = append(sqls,
 | 
				
			||||||
							sqls = append(sqls, sql)
 | 
														fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							var extraOption string
 | 
				
			||||||
 | 
							if str, ok := scope.Get("gorm:update_option"); ok {
 | 
				
			||||||
 | 
								extraOption = fmt.Sprint(str)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(sqls) > 0 {
 | 
							if len(sqls) > 0 {
 | 
				
			||||||
			scope.Raw(fmt.Sprintf(
 | 
								scope.Raw(fmt.Sprintf(
 | 
				
			||||||
				"UPDATE %v SET %v %v",
 | 
									"UPDATE %v SET %v%v%v",
 | 
				
			||||||
				scope.QuotedTableName(),
 | 
									scope.QuotedTableName(),
 | 
				
			||||||
				strings.Join(sqls, ", "),
 | 
									strings.Join(sqls, ", "),
 | 
				
			||||||
				scope.CombinedConditionSql(),
 | 
									addExtraSpaceIfExist(scope.CombinedConditionSql()),
 | 
				
			||||||
			))
 | 
									addExtraSpaceIfExist(extraOption),
 | 
				
			||||||
			scope.Exec()
 | 
								)).Exec()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func AfterUpdate(scope *Scope) {
 | 
					// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
 | 
				
			||||||
 | 
					func afterUpdateCallback(scope *Scope) {
 | 
				
			||||||
	if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
						if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
				
			||||||
		scope.CallMethodWithErrorCheck("AfterUpdate")
 | 
							if !scope.HasError() {
 | 
				
			||||||
		scope.CallMethodWithErrorCheck("AfterSave")
 | 
								scope.CallMethod("AfterUpdate")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !scope.HasError() {
 | 
				
			||||||
 | 
								scope.CallMethod("AfterSave")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func init() {
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:update", Update)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
 | 
					 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,117 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type commonDialect struct{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) BinVar(i int) string {
 | 
					 | 
				
			||||||
	return "$$" // ?
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) SupportLastInsertId() bool {
 | 
					 | 
				
			||||||
	return true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) HasTop() bool {
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
 | 
					 | 
				
			||||||
	switch value.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Bool:
 | 
					 | 
				
			||||||
		return "BOOLEAN"
 | 
					 | 
				
			||||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "INTEGER AUTO_INCREMENT"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "INTEGER"
 | 
					 | 
				
			||||||
	case reflect.Int64, reflect.Uint64:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "BIGINT AUTO_INCREMENT"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "BIGINT"
 | 
					 | 
				
			||||||
	case reflect.Float32, reflect.Float64:
 | 
					 | 
				
			||||||
		return "FLOAT"
 | 
					 | 
				
			||||||
	case reflect.String:
 | 
					 | 
				
			||||||
		if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
			return fmt.Sprintf("VARCHAR(%d)", size)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "VARCHAR(65532)"
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().(time.Time); ok {
 | 
					 | 
				
			||||||
			return "TIMESTAMP"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().([]byte); ok {
 | 
					 | 
				
			||||||
			if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
				return fmt.Sprintf("BINARY(%d)", size)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return "BINARY(65532)"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) ReturningStr(tableName, key string) string {
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) SelectFromDummyTable() string {
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) Quote(key string) string {
 | 
					 | 
				
			||||||
	return fmt.Sprintf(`"%s"`, key)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		count        int
 | 
					 | 
				
			||||||
		databaseName = c.CurrentDatabase(scope)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		count        int
 | 
					 | 
				
			||||||
		databaseName = c.CurrentDatabase(scope)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		count        int
 | 
					 | 
				
			||||||
		databaseName = c.CurrentDatabase(scope)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// RawScanInt scans the first column of the first row into the `scan' int pointer.
 | 
					 | 
				
			||||||
// This function captures raw query errors and propagates them to the original scope.
 | 
					 | 
				
			||||||
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// RawScanString scans the first column of the first row into the `scan' string pointer.
 | 
					 | 
				
			||||||
// This function captures raw query errors and propagates them to the original scope.
 | 
					 | 
				
			||||||
func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
 | 
				
			|||||||
	DB.AutoMigrate(&CustomizeColumn{})
 | 
						DB.AutoMigrate(&CustomizeColumn{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	scope := DB.NewScope(&CustomizeColumn{})
 | 
						scope := DB.NewScope(&CustomizeColumn{})
 | 
				
			||||||
	if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
 | 
						if !scope.Dialect().HasColumn(scope.TableName(), col) {
 | 
				
			||||||
		t.Errorf("CustomizeColumn should have column %s", col)
 | 
							t.Errorf("CustomizeColumn should have column %s", col)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DB.HasTable("foobarbaz")
 | 
						if err := DB.Find(&User{}).Error; err == nil {
 | 
				
			||||||
	if DB.Error == nil {
 | 
					 | 
				
			||||||
		t.Errorf("Expected operation on closed db to produce an error, but err was nil")
 | 
							t.Errorf("Expected operation on closed db to produce an error, but err was nil")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -45,7 +45,7 @@ func TestSoftDelete(t *testing.T) {
 | 
				
			|||||||
	type User struct {
 | 
						type User struct {
 | 
				
			||||||
		Id        int64
 | 
							Id        int64
 | 
				
			||||||
		Name      string
 | 
							Name      string
 | 
				
			||||||
		DeletedAt time.Time
 | 
							DeletedAt *time.Time
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	DB.AutoMigrate(&User{})
 | 
						DB.AutoMigrate(&User{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										115
									
								
								dialect.go
									
									
									
									
									
								
							
							
						
						
									
										115
									
								
								dialect.go
									
									
									
									
									
								
							@ -1,41 +1,100 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Dialect interface contains behaviors that differ across SQL database
 | 
				
			||||||
type Dialect interface {
 | 
					type Dialect interface {
 | 
				
			||||||
	BinVar(i int) string
 | 
						// GetName get dialect's name
 | 
				
			||||||
	SupportLastInsertId() bool
 | 
						GetName() string
 | 
				
			||||||
	HasTop() bool
 | 
					
 | 
				
			||||||
	SqlTag(value reflect.Value, size int, autoIncrease bool) string
 | 
						// SetDB set db for dialect
 | 
				
			||||||
	ReturningStr(tableName, key string) string
 | 
						SetDB(db *sql.DB)
 | 
				
			||||||
	SelectFromDummyTable() string
 | 
					
 | 
				
			||||||
 | 
						// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
 | 
				
			||||||
 | 
						BindVar(i int) string
 | 
				
			||||||
 | 
						// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
 | 
				
			||||||
	Quote(key string) string
 | 
						Quote(key string) string
 | 
				
			||||||
	HasTable(scope *Scope, tableName string) bool
 | 
						// DataTypeOf return data's sql type
 | 
				
			||||||
	HasColumn(scope *Scope, tableName string, columnName string) bool
 | 
						DataTypeOf(field *StructField) string
 | 
				
			||||||
	HasIndex(scope *Scope, tableName string, indexName string) bool
 | 
					
 | 
				
			||||||
	RemoveIndex(scope *Scope, indexName string)
 | 
						// HasIndex check has index or not
 | 
				
			||||||
	CurrentDatabase(scope *Scope) string
 | 
						HasIndex(tableName string, indexName string) bool
 | 
				
			||||||
 | 
						// HasForeignKey check has foreign key or not
 | 
				
			||||||
 | 
						HasForeignKey(tableName string, foreignKeyName string) bool
 | 
				
			||||||
 | 
						// RemoveIndex remove index
 | 
				
			||||||
 | 
						RemoveIndex(tableName string, indexName string) error
 | 
				
			||||||
 | 
						// HasTable check has table or not
 | 
				
			||||||
 | 
						HasTable(tableName string) bool
 | 
				
			||||||
 | 
						// HasColumn check has column or not
 | 
				
			||||||
 | 
						HasColumn(tableName string, columnName string) bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
 | 
				
			||||||
 | 
						LimitAndOffsetSQL(limit, offset int) string
 | 
				
			||||||
 | 
						// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
 | 
				
			||||||
 | 
						SelectFromDummyTable() string
 | 
				
			||||||
 | 
						// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
 | 
				
			||||||
 | 
						LastInsertIDReturningSuffix(tableName, columnName string) string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewDialect(driver string) Dialect {
 | 
					var dialectsMap = map[string]Dialect{}
 | 
				
			||||||
	var d Dialect
 | 
					
 | 
				
			||||||
	switch driver {
 | 
					func newDialect(name string, db *sql.DB) Dialect {
 | 
				
			||||||
	case "postgres":
 | 
						if value, ok := dialectsMap[name]; ok {
 | 
				
			||||||
		d = &postgres{}
 | 
							dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
 | 
				
			||||||
	case "foundation":
 | 
							dialect.SetDB(db)
 | 
				
			||||||
		d = &foundation{}
 | 
							return dialect
 | 
				
			||||||
	case "mysql":
 | 
					 | 
				
			||||||
		d = &mysql{}
 | 
					 | 
				
			||||||
	case "sqlite3":
 | 
					 | 
				
			||||||
		d = &sqlite3{}
 | 
					 | 
				
			||||||
	case "mssql":
 | 
					 | 
				
			||||||
		d = &mssql{}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
 | 
					 | 
				
			||||||
		d = &commonDialect{}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return d
 | 
					
 | 
				
			||||||
 | 
						fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
 | 
				
			||||||
 | 
						commontDialect := &commonDialect{}
 | 
				
			||||||
 | 
						commontDialect.SetDB(db)
 | 
				
			||||||
 | 
						return commontDialect
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RegisterDialect register new dialect
 | 
				
			||||||
 | 
					func RegisterDialect(name string, dialect Dialect) {
 | 
				
			||||||
 | 
						dialectsMap[name] = dialect
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ParseFieldStructForDialect parse field struct for dialect
 | 
				
			||||||
 | 
					func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
 | 
				
			||||||
 | 
						// Get redirected field type
 | 
				
			||||||
 | 
						var reflectType = field.Struct.Type
 | 
				
			||||||
 | 
						for reflectType.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
							reflectType = reflectType.Elem()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get redirected field value
 | 
				
			||||||
 | 
						fieldValue = reflect.Indirect(reflect.New(reflectType))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get scanner's real value
 | 
				
			||||||
 | 
						var getScannerValue func(reflect.Value)
 | 
				
			||||||
 | 
						getScannerValue = func(value reflect.Value) {
 | 
				
			||||||
 | 
							fieldValue = value
 | 
				
			||||||
 | 
							if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
 | 
				
			||||||
 | 
								getScannerValue(fieldValue.Field(0))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						getScannerValue(fieldValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Default Size
 | 
				
			||||||
 | 
						if num, ok := field.TagSettings["SIZE"]; ok {
 | 
				
			||||||
 | 
							size, _ = strconv.Atoi(num)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							size = 255
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Default type from tag setting
 | 
				
			||||||
 | 
						additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
 | 
				
			||||||
 | 
						if value, ok := field.TagSettings["DEFAULT"]; ok {
 | 
				
			||||||
 | 
							additionalType = additionalType + " DEFAULT " + value
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										137
									
								
								dialect_common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								dialect_common.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,137 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type commonDialect struct {
 | 
				
			||||||
 | 
						db *sql.DB
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						RegisterDialect("common", &commonDialect{})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) GetName() string {
 | 
				
			||||||
 | 
						return "common"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *commonDialect) SetDB(db *sql.DB) {
 | 
				
			||||||
 | 
						s.db = db
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) BindVar(i int) string {
 | 
				
			||||||
 | 
						return "$$" // ?
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) Quote(key string) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf(`"%s"`, key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) DataTypeOf(field *StructField) string {
 | 
				
			||||||
 | 
						var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							switch dataValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Bool:
 | 
				
			||||||
 | 
								sqlType = "BOOLEAN"
 | 
				
			||||||
 | 
							case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
 | 
				
			||||||
 | 
									sqlType = "INTEGER AUTO_INCREMENT"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "INTEGER"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Int64, reflect.Uint64:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
 | 
				
			||||||
 | 
									sqlType = "BIGINT AUTO_INCREMENT"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "BIGINT"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Float32, reflect.Float64:
 | 
				
			||||||
 | 
								sqlType = "FLOAT"
 | 
				
			||||||
 | 
							case reflect.String:
 | 
				
			||||||
 | 
								if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
									sqlType = fmt.Sprintf("VARCHAR(%d)", size)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "VARCHAR(65532)"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Struct:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().(time.Time); ok {
 | 
				
			||||||
 | 
									sqlType = "TIMESTAMP"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().([]byte); ok {
 | 
				
			||||||
 | 
									if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
										sqlType = fmt.Sprintf("BINARY(%d)", size)
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										sqlType = "BINARY(65532)"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if strings.TrimSpace(additionalType) == "" {
 | 
				
			||||||
 | 
							return sqlType
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s commonDialect) HasIndex(tableName string, indexName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
 | 
				
			||||||
 | 
						_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s commonDialect) HasTable(tableName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s commonDialect) HasColumn(tableName string, columnName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s commonDialect) currentDatabase() (name string) {
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
 | 
				
			||||||
 | 
						if limit > 0 || offset > 0 {
 | 
				
			||||||
 | 
							if limit >= 0 {
 | 
				
			||||||
 | 
								sql += fmt.Sprintf(" LIMIT %d", limit)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if offset >= 0 {
 | 
				
			||||||
 | 
								sql += fmt.Sprintf(" OFFSET %d", offset)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) SelectFromDummyTable() string {
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										113
									
								
								dialect_mysql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								dialect_mysql.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,113 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type mysql struct {
 | 
				
			||||||
 | 
						commonDialect
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						RegisterDialect("mysql", &mysql{})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mysql) GetName() string {
 | 
				
			||||||
 | 
						return "mysql"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mysql) Quote(key string) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf("`%s`", key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get Data Type for MySQL Dialect
 | 
				
			||||||
 | 
					func (mysql) DataTypeOf(field *StructField) string {
 | 
				
			||||||
 | 
						var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							switch dataValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Bool:
 | 
				
			||||||
 | 
								sqlType = "boolean"
 | 
				
			||||||
 | 
							case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "int AUTO_INCREMENT"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "int"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "int unsigned AUTO_INCREMENT"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "int unsigned"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Int64:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "bigint AUTO_INCREMENT"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "bigint"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Uint64:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "bigint unsigned AUTO_INCREMENT"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "bigint unsigned"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Float32, reflect.Float64:
 | 
				
			||||||
 | 
								sqlType = "double"
 | 
				
			||||||
 | 
							case reflect.String:
 | 
				
			||||||
 | 
								if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
									sqlType = fmt.Sprintf("varchar(%d)", size)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "longtext"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Struct:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().(time.Time); ok {
 | 
				
			||||||
 | 
									if _, ok := field.TagSettings["NOT NULL"]; ok {
 | 
				
			||||||
 | 
										sqlType = "timestamp"
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										sqlType = "timestamp NULL"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().([]byte); ok {
 | 
				
			||||||
 | 
									if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
										sqlType = fmt.Sprintf("varbinary(%d)", size)
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										sqlType = "longblob"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if strings.TrimSpace(additionalType) == "" {
 | 
				
			||||||
 | 
							return sqlType
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mysql) RemoveIndex(tableName string, indexName string) error {
 | 
				
			||||||
 | 
						_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), foreignKeyName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mysql) currentDatabase() (name string) {
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mysql) SelectFromDummyTable() string {
 | 
				
			||||||
 | 
						return "FROM DUAL"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										128
									
								
								dialect_postgres.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								dialect_postgres.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,128 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type postgres struct {
 | 
				
			||||||
 | 
						commonDialect
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						RegisterDialect("postgres", &postgres{})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (postgres) GetName() string {
 | 
				
			||||||
 | 
						return "postgres"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (postgres) BindVar(i int) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf("$%v", i)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (postgres) DataTypeOf(field *StructField) string {
 | 
				
			||||||
 | 
						var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							switch dataValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Bool:
 | 
				
			||||||
 | 
								sqlType = "boolean"
 | 
				
			||||||
 | 
							case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "serial"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "integer"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Int64, reflect.Uint64:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "bigserial"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "bigint"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Float32, reflect.Float64:
 | 
				
			||||||
 | 
								sqlType = "numeric"
 | 
				
			||||||
 | 
							case reflect.String:
 | 
				
			||||||
 | 
								if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
									sqlType = fmt.Sprintf("varchar(%d)", size)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "text"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Struct:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().(time.Time); ok {
 | 
				
			||||||
 | 
									sqlType = "timestamp with time zone"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Map:
 | 
				
			||||||
 | 
								if dataValue.Type().Name() == "Hstore" {
 | 
				
			||||||
 | 
									sqlType = "hstore"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								if isByteArrayOrSlice(dataValue) {
 | 
				
			||||||
 | 
									sqlType = "bytea"
 | 
				
			||||||
 | 
								} else if isUUID(dataValue) {
 | 
				
			||||||
 | 
									sqlType = "uuid"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if strings.TrimSpace(additionalType) == "" {
 | 
				
			||||||
 | 
							return sqlType
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s postgres) HasIndex(tableName string, indexName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s postgres) HasTable(tableName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s postgres) HasColumn(tableName string, columnName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s postgres) currentDatabase() (name string) {
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf("RETURNING %v.%v", tableName, key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (postgres) SupportLastInsertID() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func isByteArrayOrSlice(value reflect.Value) bool {
 | 
				
			||||||
 | 
						return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func isUUID(value reflect.Value) bool {
 | 
				
			||||||
 | 
						if value.Kind() != reflect.Array || value.Type().Len() != 16 {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						typename := value.Type().Name()
 | 
				
			||||||
 | 
						lower := strings.ToLower(typename)
 | 
				
			||||||
 | 
						return "uuid" == lower || "guid" == lower
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										106
									
								
								dialect_sqlite3.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								dialect_sqlite3.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,106 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type sqlite3 struct {
 | 
				
			||||||
 | 
						commonDialect
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						RegisterDialect("sqlite", &sqlite3{})
 | 
				
			||||||
 | 
						RegisterDialect("sqlite3", &sqlite3{})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (sqlite3) GetName() string {
 | 
				
			||||||
 | 
						return "sqlite3"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get Data Type for Sqlite Dialect
 | 
				
			||||||
 | 
					func (sqlite3) DataTypeOf(field *StructField) string {
 | 
				
			||||||
 | 
						var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							switch dataValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Bool:
 | 
				
			||||||
 | 
								sqlType = "bool"
 | 
				
			||||||
 | 
							case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
				
			||||||
 | 
								if field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "integer primary key autoincrement"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "integer"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Int64, reflect.Uint64:
 | 
				
			||||||
 | 
								if field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "integer primary key autoincrement"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "bigint"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Float32, reflect.Float64:
 | 
				
			||||||
 | 
								sqlType = "real"
 | 
				
			||||||
 | 
							case reflect.String:
 | 
				
			||||||
 | 
								if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
									sqlType = fmt.Sprintf("varchar(%d)", size)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "text"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Struct:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().(time.Time); ok {
 | 
				
			||||||
 | 
									sqlType = "datetime"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().([]byte); ok {
 | 
				
			||||||
 | 
									sqlType = "blob"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if strings.TrimSpace(additionalType) == "" {
 | 
				
			||||||
 | 
							return sqlType
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s sqlite3) HasIndex(tableName string, indexName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s sqlite3) HasTable(tableName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s sqlite3) HasColumn(tableName string, columnName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s sqlite3) currentDatabase() (name string) {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							ifaces   = make([]interface{}, 3)
 | 
				
			||||||
 | 
							pointers = make([]*string, 3)
 | 
				
			||||||
 | 
							i        int
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						for i = 0; i < 3; i++ {
 | 
				
			||||||
 | 
							ifaces[i] = &pointers[i]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if pointers[1] != nil {
 | 
				
			||||||
 | 
							name = *pointers[1]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										150
									
								
								dialects/mssql/mssql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								dialects/mssql/mssql.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,150 @@
 | 
				
			|||||||
 | 
					package mssql
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_ "github.com/denisenkom/go-mssqldb"
 | 
				
			||||||
 | 
						"github.com/jinzhu/gorm"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func setIdentityInsert(scope *gorm.Scope) {
 | 
				
			||||||
 | 
						if scope.Dialect().GetName() == "mssql" {
 | 
				
			||||||
 | 
							scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
 | 
				
			||||||
 | 
						gorm.RegisterDialect("mssql", &mssql{})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type mssql struct {
 | 
				
			||||||
 | 
						db *sql.DB
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mssql) GetName() string {
 | 
				
			||||||
 | 
						return "mssql"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *mssql) SetDB(db *sql.DB) {
 | 
				
			||||||
 | 
						s.db = db
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mssql) BindVar(i int) string {
 | 
				
			||||||
 | 
						return "$$" // ?
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mssql) Quote(key string) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf(`"%s"`, key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mssql) DataTypeOf(field *gorm.StructField) string {
 | 
				
			||||||
 | 
						var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							switch dataValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Bool:
 | 
				
			||||||
 | 
								sqlType = "bit"
 | 
				
			||||||
 | 
							case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "int IDENTITY(1,1)"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "int"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Int64, reflect.Uint64:
 | 
				
			||||||
 | 
								if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
 | 
				
			||||||
 | 
									sqlType = "bigint IDENTITY(1,1)"
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "bigint"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Float32, reflect.Float64:
 | 
				
			||||||
 | 
								sqlType = "float"
 | 
				
			||||||
 | 
							case reflect.String:
 | 
				
			||||||
 | 
								if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
									sqlType = fmt.Sprintf("nvarchar(%d)", size)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									sqlType = "text"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Struct:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().(time.Time); ok {
 | 
				
			||||||
 | 
									sqlType = "datetime2"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								if _, ok := dataValue.Interface().([]byte); ok {
 | 
				
			||||||
 | 
									if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
										sqlType = fmt.Sprintf("varchar(%d)", size)
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										sqlType = "text"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlType == "" {
 | 
				
			||||||
 | 
							panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if strings.TrimSpace(additionalType) == "" {
 | 
				
			||||||
 | 
							return sqlType
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mssql) HasIndex(tableName string, indexName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mssql) RemoveIndex(tableName string, indexName string) error {
 | 
				
			||||||
 | 
						_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mssql) HasTable(tableName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mssql) HasColumn(tableName string, columnName string) bool {
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
 | 
				
			||||||
 | 
						return count > 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s mssql) currentDatabase() (name string) {
 | 
				
			||||||
 | 
						s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) {
 | 
				
			||||||
 | 
						if limit > 0 || offset > 0 {
 | 
				
			||||||
 | 
							if offset < 0 {
 | 
				
			||||||
 | 
								offset = 0
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							sql += fmt.Sprintf(" OFFSET %d ROWS", offset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if limit >= 0 {
 | 
				
			||||||
 | 
								sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mssql) SelectFromDummyTable() string {
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										3
									
								
								dialects/mysql/mysql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								dialects/mysql/mysql.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					package mysql
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import _ "github.com/go-sql-driver/mysql"
 | 
				
			||||||
							
								
								
									
										52
									
								
								dialects/postgres/postgres.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								dialects/postgres/postgres.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					package postgres
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"database/sql/driver"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_ "github.com/lib/pq"
 | 
				
			||||||
 | 
						"github.com/lib/pq/hstore"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Hstore map[string]*string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h Hstore) Value() (driver.Value, error) {
 | 
				
			||||||
 | 
						hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
 | 
				
			||||||
 | 
						if len(h) == 0 {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for key, value := range h {
 | 
				
			||||||
 | 
							var s sql.NullString
 | 
				
			||||||
 | 
							if value != nil {
 | 
				
			||||||
 | 
								s.String = *value
 | 
				
			||||||
 | 
								s.Valid = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							hstore.Map[key] = s
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return hstore.Value()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Hstore) Scan(value interface{}) error {
 | 
				
			||||||
 | 
						hstore := hstore.Hstore{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := hstore.Scan(value); err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(hstore.Map) == 0 {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						*h = Hstore{}
 | 
				
			||||||
 | 
						for k := range hstore.Map {
 | 
				
			||||||
 | 
							if hstore.Map[k].Valid {
 | 
				
			||||||
 | 
								s := hstore.Map[k].String
 | 
				
			||||||
 | 
								(*h)[k] = &s
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								(*h)[k] = nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										3
									
								
								dialects/sqlite/sqlite.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								dialects/sqlite/sqlite.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					package sqlite
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import _ "github.com/mattn/go-sqlite3"
 | 
				
			||||||
@ -1,68 +0,0 @@
 | 
				
			|||||||
# Gorm Development
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Architecture
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Lets use below code to explain how it works:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    db.Where("name = ?", "jinzhu").Find(&users)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // equivalent code
 | 
					 | 
				
			||||||
    newdb := db.Where("name =?", "jinzhu")
 | 
					 | 
				
			||||||
    newdb.Find(&user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
`newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method.
 | 
					 | 
				
			||||||
`Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Callbacks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### Register a new callback
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    func updateCreated(scope *Scope) {
 | 
					 | 
				
			||||||
        if scope.HasColumn("Created") {
 | 
					 | 
				
			||||||
            scope.SetColumn("Created", NowFunc())
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    db.Callback().Create().Register("update_created_at", updateCreated)
 | 
					 | 
				
			||||||
    // register a callback for Create process
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### Delete an existing callback
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    db.Callback().Create().Remove("gorm:create")
 | 
					 | 
				
			||||||
    // delete callback `gorm:create` from Create callbacks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### Replace an existing callback
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    db.Callback().Create().Replace("gorm:create", newCreateFunction)
 | 
					 | 
				
			||||||
    // replace callback `gorm:create` with new function `newCreateFunction` for Create process
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### Register callback orders
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated)
 | 
					 | 
				
			||||||
    db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated)
 | 
					 | 
				
			||||||
    db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
 | 
					 | 
				
			||||||
    db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete)
 | 
					 | 
				
			||||||
    db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
 | 
					 | 
				
			||||||
    db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### Callback API
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API
 | 
					 | 
				
			||||||
							
								
								
									
										17
									
								
								errors.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								errors.go
									
									
									
									
									
								
							@ -6,25 +6,31 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	RecordNotFound       = errors.New("record not found")
 | 
						// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
 | 
				
			||||||
	InvalidSql           = errors.New("invalid sql")
 | 
						ErrRecordNotFound = errors.New("record not found")
 | 
				
			||||||
	NoNewAttrs           = errors.New("no new attributes")
 | 
						// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
 | 
				
			||||||
	NoValidTransaction   = errors.New("no valid transaction")
 | 
						ErrInvalidSQL = errors.New("invalid SQL")
 | 
				
			||||||
	CantStartTransaction = errors.New("can't start transaction")
 | 
						// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
 | 
				
			||||||
 | 
						ErrInvalidTransaction = errors.New("no valid transaction")
 | 
				
			||||||
 | 
						// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
 | 
				
			||||||
 | 
						ErrCantStartTransaction = errors.New("can't start transaction")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type errorsInterface interface {
 | 
					type errorsInterface interface {
 | 
				
			||||||
	GetErrors() []error
 | 
						GetErrors() []error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Errors contains all happened errors
 | 
				
			||||||
type Errors struct {
 | 
					type Errors struct {
 | 
				
			||||||
	errors []error
 | 
						errors []error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GetErrors get all happened errors
 | 
				
			||||||
func (errs Errors) GetErrors() []error {
 | 
					func (errs Errors) GetErrors() []error {
 | 
				
			||||||
	return errs.errors
 | 
						return errs.errors
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add add an error
 | 
				
			||||||
func (errs *Errors) Add(err error) {
 | 
					func (errs *Errors) Add(err error) {
 | 
				
			||||||
	if errors, ok := err.(errorsInterface); ok {
 | 
						if errors, ok := err.(errorsInterface); ok {
 | 
				
			||||||
		for _, err := range errors.GetErrors() {
 | 
							for _, err := range errors.GetErrors() {
 | 
				
			||||||
@ -40,6 +46,7 @@ func (errs *Errors) Add(err error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Error format happened errors
 | 
				
			||||||
func (errs Errors) Error() string {
 | 
					func (errs Errors) Error() string {
 | 
				
			||||||
	var errors = []string{}
 | 
						var errors = []string{}
 | 
				
			||||||
	for _, e := range errs.errors {
 | 
						for _, e := range errs.errors {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										43
									
								
								field.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								field.go
									
									
									
									
									
								
							@ -7,12 +7,14 @@ import (
 | 
				
			|||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Field model field definition
 | 
				
			||||||
type Field struct {
 | 
					type Field struct {
 | 
				
			||||||
	*StructField
 | 
						*StructField
 | 
				
			||||||
	IsBlank bool
 | 
						IsBlank bool
 | 
				
			||||||
	Field   reflect.Value
 | 
						Field   reflect.Value
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Set set a value to the field
 | 
				
			||||||
func (field *Field) Set(value interface{}) (err error) {
 | 
					func (field *Field) Set(value interface{}) (err error) {
 | 
				
			||||||
	if !field.Field.IsValid() {
 | 
						if !field.Field.IsValid() {
 | 
				
			||||||
		return errors.New("field value not valid")
 | 
							return errors.New("field value not valid")
 | 
				
			||||||
@ -56,35 +58,34 @@ func (field *Field) Set(value interface{}) (err error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Fields get value's fields
 | 
					// Fields get value's fields
 | 
				
			||||||
func (scope *Scope) Fields() map[string]*Field {
 | 
					func (scope *Scope) Fields() []*Field {
 | 
				
			||||||
	if scope.fields == nil {
 | 
						var (
 | 
				
			||||||
		fields := map[string]*Field{}
 | 
							fields             []*Field
 | 
				
			||||||
		modelStruct := scope.GetModelStruct()
 | 
							indirectScopeValue = scope.IndirectValue()
 | 
				
			||||||
 | 
							isStruct           = indirectScopeValue.Kind() == reflect.Struct
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		indirectValue := scope.IndirectValue()
 | 
						for _, structField := range scope.GetModelStruct().StructFields {
 | 
				
			||||||
		isStruct := indirectValue.Kind() == reflect.Struct
 | 
					 | 
				
			||||||
		for _, structField := range modelStruct.StructFields {
 | 
					 | 
				
			||||||
			if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
 | 
					 | 
				
			||||||
		if isStruct {
 | 
							if isStruct {
 | 
				
			||||||
					fields[structField.DBName] = getField(indirectValue, structField)
 | 
								fieldValue := indirectScopeValue
 | 
				
			||||||
				} else {
 | 
								for _, name := range structField.Names {
 | 
				
			||||||
					fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
 | 
									fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								fields = append(fields, &Field{StructField: structField, IsBlank: true})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		scope.fields = fields
 | 
					 | 
				
			||||||
	return fields
 | 
						return fields
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
	return scope.fields
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getField(indirectValue reflect.Value, structField *StructField) *Field {
 | 
					func (scope *Scope) fieldsMap() map[string]*Field {
 | 
				
			||||||
	field := &Field{StructField: structField}
 | 
						var results = map[string]*Field{}
 | 
				
			||||||
	for _, name := range structField.Names {
 | 
						for _, field := range scope.Fields() {
 | 
				
			||||||
		indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
 | 
							if field.IsNormal {
 | 
				
			||||||
 | 
								results[field.DBName] = field
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	field.Field = indirectValue
 | 
						}
 | 
				
			||||||
	field.IsBlank = isBlank(indirectValue)
 | 
						return results
 | 
				
			||||||
	return field
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -32,12 +32,16 @@ type CalculateFieldCategory struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestCalculateField(t *testing.T) {
 | 
					func TestCalculateField(t *testing.T) {
 | 
				
			||||||
	var field CalculateField
 | 
						var field CalculateField
 | 
				
			||||||
	fields := DB.NewScope(&field).Fields()
 | 
						var scope = DB.NewScope(&field)
 | 
				
			||||||
	if fields["children"].Relationship == nil || fields["category"].Relationship == nil {
 | 
						if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
 | 
				
			||||||
		t.Errorf("Should calculate fields correctly for the first time")
 | 
							t.Errorf("Should calculate fields correctly for the first time")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if field, ok := fields["embedded_name"]; !ok {
 | 
						if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
 | 
				
			||||||
 | 
							t.Errorf("Should calculate fields correctly for the first time")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if field, ok := scope.FieldByName("embedded_name"); !ok {
 | 
				
			||||||
		t.Errorf("should find embedded field")
 | 
							t.Errorf("should find embedded field")
 | 
				
			||||||
	} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
 | 
						} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
 | 
				
			||||||
		t.Errorf("should find embedded field's tag settings")
 | 
							t.Errorf("should find embedded field's tag settings")
 | 
				
			||||||
 | 
				
			|||||||
@ -1,83 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type foundation struct {
 | 
					 | 
				
			||||||
	commonDialect
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (foundation) BinVar(i int) string {
 | 
					 | 
				
			||||||
	return fmt.Sprintf("$%v", i)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (foundation) SupportLastInsertId() bool {
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
 | 
					 | 
				
			||||||
	switch value.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Bool:
 | 
					 | 
				
			||||||
		return "boolean"
 | 
					 | 
				
			||||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "serial"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "int"
 | 
					 | 
				
			||||||
	case reflect.Int64, reflect.Uint64:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "bigserial"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "bigint"
 | 
					 | 
				
			||||||
	case reflect.Float32, reflect.Float64:
 | 
					 | 
				
			||||||
		return "double"
 | 
					 | 
				
			||||||
	case reflect.String:
 | 
					 | 
				
			||||||
		if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
			return fmt.Sprintf("varchar(%d)", size)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "clob"
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().(time.Time); ok {
 | 
					 | 
				
			||||||
			return "datetime"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().([]byte); ok {
 | 
					 | 
				
			||||||
			return "blob"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s foundation) ReturningStr(tableName, key string) string {
 | 
					 | 
				
			||||||
	return fmt.Sprintf("RETURNING %v.%v", tableName, key)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s foundation) HasTable(scope *Scope, tableName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s foundation) RemoveIndex(scope *Scope, indexName string) {
 | 
					 | 
				
			||||||
	scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName)))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s foundation) CurrentDatabase(scope *Scope) (name string) {
 | 
					 | 
				
			||||||
	s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA")
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 65 KiB  | 
@ -7,40 +7,54 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// JoinTableHandlerInterface is an interface for how to handle many2many relations
 | 
				
			||||||
type JoinTableHandlerInterface interface {
 | 
					type JoinTableHandlerInterface interface {
 | 
				
			||||||
 | 
						// initialize join table handler
 | 
				
			||||||
	Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
 | 
						Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
 | 
				
			||||||
 | 
						// Table return join table's table name
 | 
				
			||||||
	Table(db *DB) string
 | 
						Table(db *DB) string
 | 
				
			||||||
 | 
						// Add create relationship in join table for source and destination
 | 
				
			||||||
	Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
 | 
						Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
 | 
				
			||||||
 | 
						// Delete delete relationship in join table for sources
 | 
				
			||||||
	Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
 | 
						Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
 | 
				
			||||||
 | 
						// JoinWith query with `Join` conditions
 | 
				
			||||||
	JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
 | 
						JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
 | 
				
			||||||
 | 
						// SourceForeignKeys return source foreign keys
 | 
				
			||||||
	SourceForeignKeys() []JoinTableForeignKey
 | 
						SourceForeignKeys() []JoinTableForeignKey
 | 
				
			||||||
 | 
						// DestinationForeignKeys return destination foreign keys
 | 
				
			||||||
	DestinationForeignKeys() []JoinTableForeignKey
 | 
						DestinationForeignKeys() []JoinTableForeignKey
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// JoinTableForeignKey join table foreign key struct
 | 
				
			||||||
type JoinTableForeignKey struct {
 | 
					type JoinTableForeignKey struct {
 | 
				
			||||||
	DBName            string
 | 
						DBName            string
 | 
				
			||||||
	AssociationDBName string
 | 
						AssociationDBName string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// JoinTableSource is a struct that contains model type and foreign keys
 | 
				
			||||||
type JoinTableSource struct {
 | 
					type JoinTableSource struct {
 | 
				
			||||||
	ModelType   reflect.Type
 | 
						ModelType   reflect.Type
 | 
				
			||||||
	ForeignKeys []JoinTableForeignKey
 | 
						ForeignKeys []JoinTableForeignKey
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// JoinTableHandler default join table handler
 | 
				
			||||||
type JoinTableHandler struct {
 | 
					type JoinTableHandler struct {
 | 
				
			||||||
	TableName   string          `sql:"-"`
 | 
						TableName   string          `sql:"-"`
 | 
				
			||||||
	Source      JoinTableSource `sql:"-"`
 | 
						Source      JoinTableSource `sql:"-"`
 | 
				
			||||||
	Destination JoinTableSource `sql:"-"`
 | 
						Destination JoinTableSource `sql:"-"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SourceForeignKeys return source foreign keys
 | 
				
			||||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
 | 
					func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
 | 
				
			||||||
	return s.Source.ForeignKeys
 | 
						return s.Source.ForeignKeys
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DestinationForeignKeys return destination foreign keys
 | 
				
			||||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
 | 
					func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
 | 
				
			||||||
	return s.Destination.ForeignKeys
 | 
						return s.Destination.ForeignKeys
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Setup initialize a default join table handler
 | 
				
			||||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
 | 
					func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
 | 
				
			||||||
	s.TableName = tableName
 | 
						s.TableName = tableName
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Table return join table's table name
 | 
				
			||||||
func (s JoinTableHandler) Table(db *DB) string {
 | 
					func (s JoinTableHandler) Table(db *DB) string {
 | 
				
			||||||
	return s.TableName
 | 
						return s.TableName
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
 | 
					func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
 | 
				
			||||||
	values := map[string]interface{}{}
 | 
						values := map[string]interface{}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, source := range sources {
 | 
						for _, source := range sources {
 | 
				
			||||||
@ -74,20 +89,25 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		if s.Source.ModelType == modelType {
 | 
							if s.Source.ModelType == modelType {
 | 
				
			||||||
			for _, foreignKey := range s.Source.ForeignKeys {
 | 
								for _, foreignKey := range s.Source.ForeignKeys {
 | 
				
			||||||
				values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
 | 
									if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
 | 
				
			||||||
 | 
										values[foreignKey.DBName] = field.Field.Interface()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else if s.Destination.ModelType == modelType {
 | 
							} else if s.Destination.ModelType == modelType {
 | 
				
			||||||
			for _, foreignKey := range s.Destination.ForeignKeys {
 | 
								for _, foreignKey := range s.Destination.ForeignKeys {
 | 
				
			||||||
				values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
 | 
									if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
 | 
				
			||||||
 | 
										values[foreignKey.DBName] = field.Field.Interface()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return values
 | 
						return values
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
 | 
					// Add create relationship in join table for source and destination
 | 
				
			||||||
 | 
					func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
 | 
				
			||||||
	scope := db.NewScope("")
 | 
						scope := db.NewScope("")
 | 
				
			||||||
	searchMap := s.GetSearchMap(db, source1, source2)
 | 
						searchMap := s.getSearchMap(db, source, destination)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var assignColumns, binVars, conditions []string
 | 
						var assignColumns, binVars, conditions []string
 | 
				
			||||||
	var values []interface{}
 | 
						var values []interface{}
 | 
				
			||||||
@ -116,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
 | 
				
			|||||||
	return db.Exec(sql, values...).Error
 | 
						return db.Exec(sql, values...).Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Delete delete relationship in join table for sources
 | 
				
			||||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
 | 
					func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		scope      = db.NewScope(nil)
 | 
							scope      = db.NewScope(nil)
 | 
				
			||||||
@ -123,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
 | 
				
			|||||||
		values     []interface{}
 | 
							values     []interface{}
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for key, value := range s.GetSearchMap(db, sources...) {
 | 
						for key, value := range s.getSearchMap(db, sources...) {
 | 
				
			||||||
		conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
 | 
							conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
 | 
				
			||||||
		values = append(values, value)
 | 
							values = append(values, value)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -131,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
 | 
				
			|||||||
	return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
 | 
						return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// JoinWith query with `Join` conditions
 | 
				
			||||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
 | 
					func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		scope           = db.NewScope(source)
 | 
							scope           = db.NewScope(source)
 | 
				
			||||||
@ -151,10 +173,12 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		for _, foreignKey := range s.Source.ForeignKeys {
 | 
							for _, foreignKey := range s.Source.ForeignKeys {
 | 
				
			||||||
			foreignDBNames = append(foreignDBNames, foreignKey.DBName)
 | 
								foreignDBNames = append(foreignDBNames, foreignKey.DBName)
 | 
				
			||||||
			foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
 | 
								if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
 | 
				
			||||||
 | 
									foreignFieldNames = append(foreignFieldNames, field.Name)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
 | 
							foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		var condString string
 | 
							var condString string
 | 
				
			||||||
		if len(foreignFieldValues) > 0 {
 | 
							if len(foreignFieldValues) > 0 {
 | 
				
			||||||
@ -165,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
 | 
								condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			keys := scope.getColumnAsArray(foreignFieldNames)
 | 
								keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
 | 
				
			||||||
			values = append(values, toQueryValues(keys))
 | 
								values = append(values, toQueryValues(keys))
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			condString = fmt.Sprintf("1 <> 1")
 | 
								condString = fmt.Sprintf("1 <> 1")
 | 
				
			||||||
@ -173,8 +197,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
 | 
							return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
 | 
				
			||||||
			Where(condString, toQueryValues(foreignFieldValues)...)
 | 
								Where(condString, toQueryValues(foreignFieldValues)...)
 | 
				
			||||||
	} else {
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db.Error = errors.New("wrong source type for join table handler")
 | 
						db.Error = errors.New("wrong source type for join table handler")
 | 
				
			||||||
	return db
 | 
						return db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -18,7 +18,7 @@ type PersonAddress struct {
 | 
				
			|||||||
	gorm.JoinTableHandler
 | 
						gorm.JoinTableHandler
 | 
				
			||||||
	PersonID  int
 | 
						PersonID  int
 | 
				
			||||||
	AddressID int
 | 
						AddressID int
 | 
				
			||||||
	DeletedAt time.Time
 | 
						DeletedAt *time.Time
 | 
				
			||||||
	CreatedAt time.Time
 | 
						CreatedAt time.Time
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										57
									
								
								logger.go
									
									
									
									
									
								
							
							
						
						
									
										57
									
								
								logger.go
									
									
									
									
									
								
							@ -8,25 +8,28 @@ import (
 | 
				
			|||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
						"unicode"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
 | 
				
			||||||
 | 
						sqlRegexp     = regexp.MustCompile(`(\$\d+)|\?`)
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type logger interface {
 | 
					type logger interface {
 | 
				
			||||||
	Print(v ...interface{})
 | 
						Print(v ...interface{})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type LogWriter interface {
 | 
					type logWriter interface {
 | 
				
			||||||
	Println(v ...interface{})
 | 
						Println(v ...interface{})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Logger default logger
 | 
				
			||||||
type Logger struct {
 | 
					type Logger struct {
 | 
				
			||||||
	LogWriter
 | 
						logWriter
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
 | 
					// Print format & print log
 | 
				
			||||||
 | 
					 | 
				
			||||||
// Format log
 | 
					 | 
				
			||||||
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (logger Logger) Print(values ...interface{}) {
 | 
					func (logger Logger) Print(values ...interface{}) {
 | 
				
			||||||
	if len(values) > 1 {
 | 
						if len(values) > 1 {
 | 
				
			||||||
		level := values[0]
 | 
							level := values[0]
 | 
				
			||||||
@ -38,29 +41,44 @@ func (logger Logger) Print(values ...interface{}) {
 | 
				
			|||||||
			// duration
 | 
								// duration
 | 
				
			||||||
			messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
 | 
								messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
 | 
				
			||||||
			// sql
 | 
								// sql
 | 
				
			||||||
			var formatedValues []interface{}
 | 
								var sql string
 | 
				
			||||||
 | 
								var formattedValues []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, value := range values[4].([]interface{}) {
 | 
								for _, value := range values[4].([]interface{}) {
 | 
				
			||||||
				indirectValue := reflect.Indirect(reflect.ValueOf(value))
 | 
									indirectValue := reflect.Indirect(reflect.ValueOf(value))
 | 
				
			||||||
				if indirectValue.IsValid() {
 | 
									if indirectValue.IsValid() {
 | 
				
			||||||
					value = indirectValue.Interface()
 | 
										value = indirectValue.Interface()
 | 
				
			||||||
					if t, ok := value.(time.Time); ok {
 | 
										if t, ok := value.(time.Time); ok {
 | 
				
			||||||
						formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
 | 
											formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
 | 
				
			||||||
					} else if b, ok := value.([]byte); ok {
 | 
										} else if b, ok := value.([]byte); ok {
 | 
				
			||||||
						formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b)))
 | 
											if str := string(b); isPrintable(str) {
 | 
				
			||||||
 | 
												formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
 | 
				
			||||||
 | 
											} else {
 | 
				
			||||||
 | 
												formattedValues = append(formattedValues, "'<binary>'")
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
					} else if r, ok := value.(driver.Valuer); ok {
 | 
										} else if r, ok := value.(driver.Valuer); ok {
 | 
				
			||||||
						if value, err := r.Value(); err == nil && value != nil {
 | 
											if value, err := r.Value(); err == nil && value != nil {
 | 
				
			||||||
							formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
 | 
												formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
 | 
				
			||||||
						} else {
 | 
											} else {
 | 
				
			||||||
							formatedValues = append(formatedValues, "NULL")
 | 
												formattedValues = append(formattedValues, "NULL")
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					} else {
 | 
										} else {
 | 
				
			||||||
						formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
 | 
											formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
 | 
										formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
 | 
					
 | 
				
			||||||
 | 
								var formattedValuesLength = len(formattedValues)
 | 
				
			||||||
 | 
								for index, value := range sqlRegexp.Split(values[3].(string), -1) {
 | 
				
			||||||
 | 
									sql += value
 | 
				
			||||||
 | 
									if index < formattedValuesLength {
 | 
				
			||||||
 | 
										sql += formattedValues[index]
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								messages = append(messages, sql)
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			messages = append(messages, "\033[31;1m")
 | 
								messages = append(messages, "\033[31;1m")
 | 
				
			||||||
			messages = append(messages, values[2:]...)
 | 
								messages = append(messages, values[2:]...)
 | 
				
			||||||
@ -69,3 +87,12 @@ func (logger Logger) Print(values ...interface{}) {
 | 
				
			|||||||
		logger.Println(messages...)
 | 
							logger.Println(messages...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func isPrintable(s string) bool {
 | 
				
			||||||
 | 
						for _, r := range s {
 | 
				
			||||||
 | 
							if !unicode.IsPrint(r) {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										279
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										279
									
								
								main.go
									
									
									
									
									
								
							@ -6,24 +6,14 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NowFunc returns current time, this function is exported in order to be able
 | 
					// DB contains information for current db connection
 | 
				
			||||||
// to give the flexibility to the developer to customize it according to their
 | 
					 | 
				
			||||||
// needs
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//   e.g: return time.Now().UTC()
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
var NowFunc = func() time.Time {
 | 
					 | 
				
			||||||
	return time.Now()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type DB struct {
 | 
					type DB struct {
 | 
				
			||||||
	Value             interface{}
 | 
						Value             interface{}
 | 
				
			||||||
	Error             error
 | 
						Error             error
 | 
				
			||||||
	RowsAffected      int64
 | 
						RowsAffected      int64
 | 
				
			||||||
	callback          *callback
 | 
						callbacks         *Callback
 | 
				
			||||||
	db                sqlCommon
 | 
						db                sqlCommon
 | 
				
			||||||
	parent            *DB
 | 
						parent            *DB
 | 
				
			||||||
	search            *search
 | 
						search            *search
 | 
				
			||||||
@ -36,7 +26,18 @@ type DB struct {
 | 
				
			|||||||
	joinTableHandlers map[string]JoinTableHandler
 | 
						joinTableHandlers map[string]JoinTableHandler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Open(dialect string, args ...interface{}) (DB, error) {
 | 
					// Open initialize a new db connection, need to import driver first, e.g:
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     import _ "github.com/go-sql-driver/mysql"
 | 
				
			||||||
 | 
					//     func main() {
 | 
				
			||||||
 | 
					//       db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
 | 
				
			||||||
 | 
					//     }
 | 
				
			||||||
 | 
					// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
 | 
				
			||||||
 | 
					//    import _ "github.com/jinzhu/gorm/dialects/mysql"
 | 
				
			||||||
 | 
					//    // import _ "github.com/jinzhu/gorm/dialects/postgres"
 | 
				
			||||||
 | 
					//    // import _ "github.com/jinzhu/gorm/dialects/sqlite"
 | 
				
			||||||
 | 
					//    // import _ "github.com/jinzhu/gorm/dialects/mssql"
 | 
				
			||||||
 | 
					func Open(dialect string, args ...interface{}) (*DB, error) {
 | 
				
			||||||
	var db DB
 | 
						var db DB
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -44,7 +45,7 @@ func Open(dialect string, args ...interface{}) (DB, error) {
 | 
				
			|||||||
		err = errors.New("invalid database source")
 | 
							err = errors.New("invalid database source")
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		var source string
 | 
							var source string
 | 
				
			||||||
		var dbSql sqlCommon
 | 
							var dbSQL sqlCommon
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		switch value := args[0].(type) {
 | 
							switch value := args[0].(type) {
 | 
				
			||||||
		case string:
 | 
							case string:
 | 
				
			||||||
@ -55,22 +56,19 @@ func Open(dialect string, args ...interface{}) (DB, error) {
 | 
				
			|||||||
				driver = value
 | 
									driver = value
 | 
				
			||||||
				source = args[1].(string)
 | 
									source = args[1].(string)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if driver == "foundation" {
 | 
								dbSQL, err = sql.Open(driver, source)
 | 
				
			||||||
				driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			dbSql, err = sql.Open(driver, source)
 | 
					 | 
				
			||||||
		case sqlCommon:
 | 
							case sqlCommon:
 | 
				
			||||||
			source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
 | 
								source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
 | 
				
			||||||
			dbSql = value
 | 
								dbSQL = value
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		db = DB{
 | 
							db = DB{
 | 
				
			||||||
			dialect:  NewDialect(dialect),
 | 
								dialect:   newDialect(dialect, dbSQL.(*sql.DB)),
 | 
				
			||||||
			logger:    defaultLogger,
 | 
								logger:    defaultLogger,
 | 
				
			||||||
			callback: DefaultCallback,
 | 
								callbacks: DefaultCallback,
 | 
				
			||||||
			source:    source,
 | 
								source:    source,
 | 
				
			||||||
			values:    map[string]interface{}{},
 | 
								values:    map[string]interface{}{},
 | 
				
			||||||
			db:       dbSql,
 | 
								db:        dbSQL,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		db.parent = &db
 | 
							db.parent = &db
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -79,17 +77,20 @@ func Open(dialect string, args ...interface{}) (DB, error) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return db, err
 | 
						return &db, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Close close current db connection
 | 
				
			||||||
func (s *DB) Close() error {
 | 
					func (s *DB) Close() error {
 | 
				
			||||||
	return s.parent.db.(*sql.DB).Close()
 | 
						return s.parent.db.(*sql.DB).Close()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DB get `*sql.DB` from current connection
 | 
				
			||||||
func (s *DB) DB() *sql.DB {
 | 
					func (s *DB) DB() *sql.DB {
 | 
				
			||||||
	return s.db.(*sql.DB)
 | 
						return s.db.(*sql.DB)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// New clone a new db connection without search conditions
 | 
				
			||||||
func (s *DB) New() *DB {
 | 
					func (s *DB) New() *DB {
 | 
				
			||||||
	clone := s.clone()
 | 
						clone := s.clone()
 | 
				
			||||||
	clone.search = nil
 | 
						clone.search = nil
 | 
				
			||||||
@ -97,29 +98,32 @@ func (s *DB) New() *DB {
 | 
				
			|||||||
	return clone
 | 
						return clone
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewScope create scope for callbacks, including DB's search information
 | 
					// NewScope create a scope for current operation
 | 
				
			||||||
func (db *DB) NewScope(value interface{}) *Scope {
 | 
					func (s *DB) NewScope(value interface{}) *Scope {
 | 
				
			||||||
	dbClone := db.clone()
 | 
						dbClone := s.clone()
 | 
				
			||||||
	dbClone.Value = value
 | 
						dbClone.Value = value
 | 
				
			||||||
	return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
 | 
						return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
 | 
					// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
 | 
				
			||||||
// Use of this method is discouraged. It's mainly intended to allow
 | 
					 | 
				
			||||||
// coexistence with legacy non-GORM code.
 | 
					 | 
				
			||||||
func (s *DB) CommonDB() sqlCommon {
 | 
					func (s *DB) CommonDB() sqlCommon {
 | 
				
			||||||
	return s.db
 | 
						return s.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) Callback() *callback {
 | 
					// Callback return `Callbacks` container, you could add/change/delete callbacks with it
 | 
				
			||||||
	s.parent.callback = s.parent.callback.clone()
 | 
					//     db.Callback().Create().Register("update_created_at", updateCreated)
 | 
				
			||||||
	return s.parent.callback
 | 
					// Refer https://jinzhu.github.io/gorm/development.html#callbacks
 | 
				
			||||||
 | 
					func (s *DB) Callback() *Callback {
 | 
				
			||||||
 | 
						s.parent.callbacks = s.parent.callbacks.clone()
 | 
				
			||||||
 | 
						return s.parent.callbacks
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) SetLogger(l logger) {
 | 
					// SetLogger replace default logger
 | 
				
			||||||
	s.logger = l
 | 
					func (s *DB) SetLogger(log logger) {
 | 
				
			||||||
 | 
						s.logger = log
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
 | 
				
			||||||
func (s *DB) LogMode(enable bool) *DB {
 | 
					func (s *DB) LogMode(enable bool) *DB {
 | 
				
			||||||
	if enable {
 | 
						if enable {
 | 
				
			||||||
		s.logMode = 2
 | 
							s.logMode = 2
 | 
				
			||||||
@ -129,55 +133,82 @@ func (s *DB) LogMode(enable bool) *DB {
 | 
				
			|||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SingularTable use singular table by default
 | 
				
			||||||
func (s *DB) SingularTable(enable bool) {
 | 
					func (s *DB) SingularTable(enable bool) {
 | 
				
			||||||
	modelStructsMap = newModelStructsMap()
 | 
						modelStructsMap = newModelStructsMap()
 | 
				
			||||||
	s.parent.singularTable = enable
 | 
						s.parent.singularTable = enable
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
 | 
				
			||||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
 | 
					func (s *DB) Where(query interface{}, args ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Where(query, args...).db
 | 
						return s.clone().search.Where(query, args...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Or filter records that match before conditions or this one, similar to `Where`
 | 
				
			||||||
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
 | 
					func (s *DB) Or(query interface{}, args ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Or(query, args...).db
 | 
						return s.clone().search.Or(query, args...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Not filter records that don't match current conditions, similar to `Where`
 | 
				
			||||||
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
 | 
					func (s *DB) Not(query interface{}, args ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Not(query, args...).db
 | 
						return s.clone().search.Not(query, args...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) Limit(value interface{}) *DB {
 | 
					// Limit specify the number of records to be retrieved
 | 
				
			||||||
	return s.clone().search.Limit(value).db
 | 
					func (s *DB) Limit(limit int) *DB {
 | 
				
			||||||
 | 
						return s.clone().search.Limit(limit).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) Offset(value interface{}) *DB {
 | 
					// Offset specify the number of records to skip before starting to return the records
 | 
				
			||||||
	return s.clone().search.Offset(value).db
 | 
					func (s *DB) Offset(offset int) *DB {
 | 
				
			||||||
 | 
						return s.clone().search.Offset(offset).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
 | 
				
			||||||
func (s *DB) Order(value string, reorder ...bool) *DB {
 | 
					func (s *DB) Order(value string, reorder ...bool) *DB {
 | 
				
			||||||
	return s.clone().search.Order(value, reorder...).db
 | 
						return s.clone().search.Order(value, reorder...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
 | 
				
			||||||
 | 
					// When creating/updating, specify fields that you want to save to database
 | 
				
			||||||
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
 | 
					func (s *DB) Select(query interface{}, args ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Select(query, args...).db
 | 
						return s.clone().search.Select(query, args...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Omit specify fields that you want to ignore when saving to database for creating, updating
 | 
				
			||||||
func (s *DB) Omit(columns ...string) *DB {
 | 
					func (s *DB) Omit(columns ...string) *DB {
 | 
				
			||||||
	return s.clone().search.Omit(columns...).db
 | 
						return s.clone().search.Omit(columns...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Group specify the group method on the find
 | 
				
			||||||
func (s *DB) Group(query string) *DB {
 | 
					func (s *DB) Group(query string) *DB {
 | 
				
			||||||
	return s.clone().search.Group(query).db
 | 
						return s.clone().search.Group(query).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Having specify HAVING conditions for GROUP BY
 | 
				
			||||||
func (s *DB) Having(query string, values ...interface{}) *DB {
 | 
					func (s *DB) Having(query string, values ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Having(query, values...).db
 | 
						return s.clone().search.Having(query, values...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) Joins(query string) *DB {
 | 
					// Joins specify Joins conditions
 | 
				
			||||||
	return s.clone().search.Joins(query).db
 | 
					//     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | 
				
			||||||
 | 
					func (s *DB) Joins(query string, args ...interface{}) *DB {
 | 
				
			||||||
 | 
						return s.clone().search.Joins(query, args...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
 | 
				
			||||||
 | 
					//     func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
 | 
				
			||||||
 | 
					//         return db.Where("amount > ?", 1000)
 | 
				
			||||||
 | 
					//     }
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
 | 
				
			||||||
 | 
					//         return func (db *gorm.DB) *gorm.DB {
 | 
				
			||||||
 | 
					//             return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
 | 
				
			||||||
 | 
					//         }
 | 
				
			||||||
 | 
					//     }
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
 | 
				
			||||||
 | 
					// Refer https://jinzhu.github.io/gorm/curd.html#scopes
 | 
				
			||||||
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
 | 
					func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
 | 
				
			||||||
	for _, f := range funcs {
 | 
						for _, f := range funcs {
 | 
				
			||||||
		s = f(s)
 | 
							s = f(s)
 | 
				
			||||||
@ -185,60 +216,91 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
 | 
				
			|||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete
 | 
				
			||||||
func (s *DB) Unscoped() *DB {
 | 
					func (s *DB) Unscoped() *DB {
 | 
				
			||||||
	return s.clone().search.unscoped().db
 | 
						return s.clone().search.unscoped().db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
 | 
				
			||||||
func (s *DB) Attrs(attrs ...interface{}) *DB {
 | 
					func (s *DB) Attrs(attrs ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Attrs(attrs...).db
 | 
						return s.clone().search.Attrs(attrs...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
 | 
				
			||||||
func (s *DB) Assign(attrs ...interface{}) *DB {
 | 
					func (s *DB) Assign(attrs ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Assign(attrs...).db
 | 
						return s.clone().search.Assign(attrs...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// First find first record that match given conditions, order by primary key
 | 
				
			||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
 | 
					func (s *DB) First(out interface{}, where ...interface{}) *DB {
 | 
				
			||||||
	newScope := s.clone().NewScope(out)
 | 
						newScope := s.clone().NewScope(out)
 | 
				
			||||||
	newScope.Search.Limit(1)
 | 
						newScope.Search.Limit(1)
 | 
				
			||||||
	return newScope.Set("gorm:order_by_primary_key", "ASC").
 | 
						return newScope.Set("gorm:order_by_primary_key", "ASC").
 | 
				
			||||||
		inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
 | 
							inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Last find last record that match given conditions, order by primary key
 | 
				
			||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
 | 
					func (s *DB) Last(out interface{}, where ...interface{}) *DB {
 | 
				
			||||||
	newScope := s.clone().NewScope(out)
 | 
						newScope := s.clone().NewScope(out)
 | 
				
			||||||
	newScope.Search.Limit(1)
 | 
						newScope.Search.Limit(1)
 | 
				
			||||||
	return newScope.Set("gorm:order_by_primary_key", "DESC").
 | 
						return newScope.Set("gorm:order_by_primary_key", "DESC").
 | 
				
			||||||
		inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
 | 
							inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Find find records that match given conditions
 | 
				
			||||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
 | 
					func (s *DB) Find(out interface{}, where ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
 | 
						return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Scan scan value to a struct
 | 
				
			||||||
func (s *DB) Scan(dest interface{}) *DB {
 | 
					func (s *DB) Scan(dest interface{}) *DB {
 | 
				
			||||||
	return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
 | 
						return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Row return `*sql.Row` with given conditions
 | 
				
			||||||
func (s *DB) Row() *sql.Row {
 | 
					func (s *DB) Row() *sql.Row {
 | 
				
			||||||
	return s.NewScope(s.Value).row()
 | 
						return s.NewScope(s.Value).row()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Rows return `*sql.Rows` with given conditions
 | 
				
			||||||
func (s *DB) Rows() (*sql.Rows, error) {
 | 
					func (s *DB) Rows() (*sql.Rows, error) {
 | 
				
			||||||
	return s.NewScope(s.Value).rows()
 | 
						return s.NewScope(s.Value).rows()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ScanRows scan `*sql.Rows` to give struct
 | 
				
			||||||
 | 
					func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							clone        = s.clone()
 | 
				
			||||||
 | 
							scope        = clone.NewScope(result)
 | 
				
			||||||
 | 
							columns, err = rows.Columns()
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if clone.AddError(err) == nil {
 | 
				
			||||||
 | 
							scope.scan(rows, columns, scope.fieldsMap())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return clone.Error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Pluck used to query single column from a model as a map
 | 
				
			||||||
 | 
					//     var ages []int64
 | 
				
			||||||
 | 
					//     db.Find(&users).Pluck("age", &ages)
 | 
				
			||||||
func (s *DB) Pluck(column string, value interface{}) *DB {
 | 
					func (s *DB) Pluck(column string, value interface{}) *DB {
 | 
				
			||||||
	return s.NewScope(s.Value).pluck(column, value).db
 | 
						return s.NewScope(s.Value).pluck(column, value).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Count get how many records for a model
 | 
				
			||||||
func (s *DB) Count(value interface{}) *DB {
 | 
					func (s *DB) Count(value interface{}) *DB {
 | 
				
			||||||
	return s.NewScope(s.Value).count(value).db
 | 
						return s.NewScope(s.Value).count(value).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Related get related associations
 | 
				
			||||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
 | 
					func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
 | 
				
			||||||
	return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
 | 
						return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions)
 | 
				
			||||||
 | 
					// https://jinzhu.github.io/gorm/curd.html#firstorinit
 | 
				
			||||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
 | 
					func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
 | 
				
			||||||
	c := s.clone()
 | 
						c := s.clone()
 | 
				
			||||||
	if result := c.First(out, where...); result.Error != nil {
 | 
						if result := c.First(out, where...); result.Error != nil {
 | 
				
			||||||
@ -247,82 +309,100 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		c.NewScope(out).inlineCondition(where...).initialize()
 | 
							c.NewScope(out).inlineCondition(where...).initialize()
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false)
 | 
							c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return c
 | 
						return c
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
 | 
				
			||||||
 | 
					// https://jinzhu.github.io/gorm/curd.html#firstorcreate
 | 
				
			||||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
 | 
					func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
 | 
				
			||||||
	c := s.clone()
 | 
						c := s.clone()
 | 
				
			||||||
	if result := c.First(out, where...); result.Error != nil {
 | 
						if result := c.First(out, where...); result.Error != nil {
 | 
				
			||||||
		if !result.RecordNotFound() {
 | 
							if !result.RecordNotFound() {
 | 
				
			||||||
			return result
 | 
								return result
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error)
 | 
							c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error)
 | 
				
			||||||
	} else if len(c.search.assignAttrs) > 0 {
 | 
						} else if len(c.search.assignAttrs) > 0 {
 | 
				
			||||||
		c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error)
 | 
							c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return c
 | 
						return c
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
 | 
				
			||||||
func (s *DB) Update(attrs ...interface{}) *DB {
 | 
					func (s *DB) Update(attrs ...interface{}) *DB {
 | 
				
			||||||
	return s.Updates(toSearchableMap(attrs...), true)
 | 
						return s.Updates(toSearchableMap(attrs...), true)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
 | 
				
			||||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
 | 
					func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
 | 
				
			||||||
	return s.clone().NewScope(s.Value).
 | 
						return s.clone().NewScope(s.Value).
 | 
				
			||||||
		Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
 | 
							Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
 | 
				
			||||||
		InstanceSet("gorm:update_interface", values).
 | 
							InstanceSet("gorm:update_interface", values).
 | 
				
			||||||
		callCallbacks(s.parent.callback.updates).db
 | 
							callCallbacks(s.parent.callbacks.updates).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
 | 
				
			||||||
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
 | 
					func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
 | 
				
			||||||
	return s.UpdateColumns(toSearchableMap(attrs...))
 | 
						return s.UpdateColumns(toSearchableMap(attrs...))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
 | 
				
			||||||
func (s *DB) UpdateColumns(values interface{}) *DB {
 | 
					func (s *DB) UpdateColumns(values interface{}) *DB {
 | 
				
			||||||
	return s.clone().NewScope(s.Value).
 | 
						return s.clone().NewScope(s.Value).
 | 
				
			||||||
		Set("gorm:update_column", true).
 | 
							Set("gorm:update_column", true).
 | 
				
			||||||
		Set("gorm:save_associations", false).
 | 
							Set("gorm:save_associations", false).
 | 
				
			||||||
		InstanceSet("gorm:update_interface", values).
 | 
							InstanceSet("gorm:update_interface", values).
 | 
				
			||||||
		callCallbacks(s.parent.callback.updates).db
 | 
							callCallbacks(s.parent.callbacks.updates).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Save update value in database, if the value doesn't have primary key, will insert it
 | 
				
			||||||
func (s *DB) Save(value interface{}) *DB {
 | 
					func (s *DB) Save(value interface{}) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(value)
 | 
						scope := s.clone().NewScope(value)
 | 
				
			||||||
	if scope.PrimaryKeyZero() {
 | 
						if scope.PrimaryKeyZero() {
 | 
				
			||||||
		return scope.callCallbacks(s.parent.callback.creates).db
 | 
							return scope.callCallbacks(s.parent.callbacks.creates).db
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return scope.callCallbacks(s.parent.callback.updates).db
 | 
						return scope.callCallbacks(s.parent.callbacks.updates).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create insert the value into database
 | 
				
			||||||
func (s *DB) Create(value interface{}) *DB {
 | 
					func (s *DB) Create(value interface{}) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(value)
 | 
						scope := s.clone().NewScope(value)
 | 
				
			||||||
	return scope.callCallbacks(s.parent.callback.creates).db
 | 
						return scope.callCallbacks(s.parent.callbacks.creates).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
 | 
				
			||||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
 | 
					func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
 | 
						return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Raw use raw sql as conditions, won't run it unless invoked by other methods
 | 
				
			||||||
 | 
					//    db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
 | 
				
			||||||
func (s *DB) Raw(sql string, values ...interface{}) *DB {
 | 
					func (s *DB) Raw(sql string, values ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Raw(true).Where(sql, values...).db
 | 
						return s.clone().search.Raw(true).Where(sql, values...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Exec execute raw sql
 | 
				
			||||||
func (s *DB) Exec(sql string, values ...interface{}) *DB {
 | 
					func (s *DB) Exec(sql string, values ...interface{}) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(nil)
 | 
						scope := s.clone().NewScope(nil)
 | 
				
			||||||
	generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
 | 
						generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
 | 
				
			||||||
	generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
 | 
						generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
 | 
				
			||||||
	scope.Raw(generatedSql)
 | 
						scope.Raw(generatedSQL)
 | 
				
			||||||
	return scope.Exec().db
 | 
						return scope.Exec().db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Model specify the model you would like to run db operations
 | 
				
			||||||
 | 
					//    // update all users's name to `hello`
 | 
				
			||||||
 | 
					//    db.Model(&User{}).Update("name", "hello")
 | 
				
			||||||
 | 
					//    // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
 | 
				
			||||||
 | 
					//    db.Model(&user).Update("name", "hello")
 | 
				
			||||||
func (s *DB) Model(value interface{}) *DB {
 | 
					func (s *DB) Model(value interface{}) *DB {
 | 
				
			||||||
	c := s.clone()
 | 
						c := s.clone()
 | 
				
			||||||
	c.Value = value
 | 
						c.Value = value
 | 
				
			||||||
	return c
 | 
						return c
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Table specify the table you would like to run db operations
 | 
				
			||||||
func (s *DB) Table(name string) *DB {
 | 
					func (s *DB) Table(name string) *DB {
 | 
				
			||||||
	clone := s.clone()
 | 
						clone := s.clone()
 | 
				
			||||||
	clone.search.Table(name)
 | 
						clone.search.Table(name)
 | 
				
			||||||
@ -330,10 +410,12 @@ func (s *DB) Table(name string) *DB {
 | 
				
			|||||||
	return clone
 | 
						return clone
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Debug start debug mode
 | 
				
			||||||
func (s *DB) Debug() *DB {
 | 
					func (s *DB) Debug() *DB {
 | 
				
			||||||
	return s.clone().LogMode(true)
 | 
						return s.clone().LogMode(true)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Begin begin a transaction
 | 
				
			||||||
func (s *DB) Begin() *DB {
 | 
					func (s *DB) Begin() *DB {
 | 
				
			||||||
	c := s.clone()
 | 
						c := s.clone()
 | 
				
			||||||
	if db, ok := c.db.(sqlDb); ok {
 | 
						if db, ok := c.db.(sqlDb); ok {
 | 
				
			||||||
@ -341,46 +423,56 @@ func (s *DB) Begin() *DB {
 | 
				
			|||||||
		c.db = interface{}(tx).(sqlCommon)
 | 
							c.db = interface{}(tx).(sqlCommon)
 | 
				
			||||||
		c.AddError(err)
 | 
							c.AddError(err)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		c.AddError(CantStartTransaction)
 | 
							c.AddError(ErrCantStartTransaction)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return c
 | 
						return c
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Commit commit a transaction
 | 
				
			||||||
func (s *DB) Commit() *DB {
 | 
					func (s *DB) Commit() *DB {
 | 
				
			||||||
	if db, ok := s.db.(sqlTx); ok {
 | 
						if db, ok := s.db.(sqlTx); ok {
 | 
				
			||||||
		s.AddError(db.Commit())
 | 
							s.AddError(db.Commit())
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		s.AddError(NoValidTransaction)
 | 
							s.AddError(ErrInvalidTransaction)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Rollback rollback a transaction
 | 
				
			||||||
func (s *DB) Rollback() *DB {
 | 
					func (s *DB) Rollback() *DB {
 | 
				
			||||||
	if db, ok := s.db.(sqlTx); ok {
 | 
						if db, ok := s.db.(sqlTx); ok {
 | 
				
			||||||
		s.AddError(db.Rollback())
 | 
							s.AddError(db.Rollback())
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		s.AddError(NoValidTransaction)
 | 
							s.AddError(ErrInvalidTransaction)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewRecord check if value's primary key is blank
 | 
				
			||||||
func (s *DB) NewRecord(value interface{}) bool {
 | 
					func (s *DB) NewRecord(value interface{}) bool {
 | 
				
			||||||
	return s.clone().NewScope(value).PrimaryKeyZero()
 | 
						return s.clone().NewScope(value).PrimaryKeyZero()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RecordNotFound check if returning ErrRecordNotFound error
 | 
				
			||||||
func (s *DB) RecordNotFound() bool {
 | 
					func (s *DB) RecordNotFound() bool {
 | 
				
			||||||
	return s.Error == RecordNotFound
 | 
						for _, err := range s.GetErrors() {
 | 
				
			||||||
 | 
							if err == ErrRecordNotFound {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Migrations
 | 
					// CreateTable create table for models
 | 
				
			||||||
func (s *DB) CreateTable(values ...interface{}) *DB {
 | 
					func (s *DB) CreateTable(models ...interface{}) *DB {
 | 
				
			||||||
	db := s.clone()
 | 
						db := s.clone()
 | 
				
			||||||
	for _, value := range values {
 | 
						for _, model := range models {
 | 
				
			||||||
		db = db.NewScope(value).createTable().db
 | 
							db = db.NewScope(model).createTable().db
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return db
 | 
						return db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropTable drop table for models
 | 
				
			||||||
func (s *DB) DropTable(values ...interface{}) *DB {
 | 
					func (s *DB) DropTable(values ...interface{}) *DB {
 | 
				
			||||||
	db := s.clone()
 | 
						db := s.clone()
 | 
				
			||||||
	for _, value := range values {
 | 
						for _, value := range values {
 | 
				
			||||||
@ -393,18 +485,18 @@ func (s *DB) DropTable(values ...interface{}) *DB {
 | 
				
			|||||||
	return db
 | 
						return db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropTableIfExists drop table if it is exist
 | 
				
			||||||
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
 | 
					func (s *DB) DropTableIfExists(values ...interface{}) *DB {
 | 
				
			||||||
	db := s.clone()
 | 
						db := s.clone()
 | 
				
			||||||
	for _, value := range values {
 | 
						for _, value := range values {
 | 
				
			||||||
		if tableName, ok := value.(string); ok {
 | 
							if s.HasTable(value) {
 | 
				
			||||||
			db = db.Table(tableName)
 | 
								db.AddError(s.DropTable(value).Error)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
		db = db.NewScope(value).dropTableIfExists().db
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return db
 | 
						return db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HasTable check has table or not
 | 
				
			||||||
func (s *DB) HasTable(value interface{}) bool {
 | 
					func (s *DB) HasTable(value interface{}) bool {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		scope     = s.clone().NewScope(value)
 | 
							scope     = s.clone().NewScope(value)
 | 
				
			||||||
@ -417,69 +509,64 @@ func (s *DB) HasTable(value interface{}) bool {
 | 
				
			|||||||
		tableName = scope.TableName()
 | 
							tableName = scope.TableName()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	has := scope.Dialect().HasTable(scope, tableName)
 | 
						has := scope.Dialect().HasTable(tableName)
 | 
				
			||||||
	s.AddError(scope.db.Error)
 | 
						s.AddError(scope.db.Error)
 | 
				
			||||||
	return has
 | 
						return has
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
 | 
				
			||||||
func (s *DB) AutoMigrate(values ...interface{}) *DB {
 | 
					func (s *DB) AutoMigrate(values ...interface{}) *DB {
 | 
				
			||||||
	db := s.clone()
 | 
						db := s.clone()
 | 
				
			||||||
	for _, value := range values {
 | 
						for _, value := range values {
 | 
				
			||||||
		db = db.NewScope(value).NeedPtr().autoMigrate().db
 | 
							db = db.NewScope(value).autoMigrate().db
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return db
 | 
						return db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ModifyColumn modify column to type
 | 
				
			||||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
 | 
					func (s *DB) ModifyColumn(column string, typ string) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(s.Value)
 | 
						scope := s.clone().NewScope(s.Value)
 | 
				
			||||||
	scope.modifyColumn(column, typ)
 | 
						scope.modifyColumn(column, typ)
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropColumn drop a column
 | 
				
			||||||
func (s *DB) DropColumn(column string) *DB {
 | 
					func (s *DB) DropColumn(column string) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(s.Value)
 | 
						scope := s.clone().NewScope(s.Value)
 | 
				
			||||||
	scope.dropColumn(column)
 | 
						scope.dropColumn(column)
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) AddIndex(indexName string, column ...string) *DB {
 | 
					// AddIndex add index for columns with given name
 | 
				
			||||||
 | 
					func (s *DB) AddIndex(indexName string, columns ...string) *DB {
 | 
				
			||||||
	scope := s.Unscoped().NewScope(s.Value)
 | 
						scope := s.Unscoped().NewScope(s.Value)
 | 
				
			||||||
	scope.addIndex(false, indexName, column...)
 | 
						scope.addIndex(false, indexName, columns...)
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
 | 
					// AddUniqueIndex add unique index for columns with given name
 | 
				
			||||||
 | 
					func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(s.Value)
 | 
						scope := s.clone().NewScope(s.Value)
 | 
				
			||||||
	scope.addIndex(true, indexName, column...)
 | 
						scope.addIndex(true, indexName, columns...)
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RemoveIndex remove index with name
 | 
				
			||||||
func (s *DB) RemoveIndex(indexName string) *DB {
 | 
					func (s *DB) RemoveIndex(indexName string) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(s.Value)
 | 
						scope := s.clone().NewScope(s.Value)
 | 
				
			||||||
	scope.removeIndex(indexName)
 | 
						scope.removeIndex(indexName)
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) CurrentDatabase() string {
 | 
					// AddForeignKey Add foreign key to the given scope, e.g:
 | 
				
			||||||
	var (
 | 
					//     db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
 | 
				
			||||||
		scope = s.clone().NewScope(s.Value)
 | 
					 | 
				
			||||||
		name  = s.dialect.CurrentDatabase(scope)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	return name
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/*
 | 
					 | 
				
			||||||
Add foreign key to the given scope
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Example:
 | 
					 | 
				
			||||||
	db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
 | 
					 | 
				
			||||||
*/
 | 
					 | 
				
			||||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
 | 
					func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
 | 
				
			||||||
	scope := s.clone().NewScope(s.Value)
 | 
						scope := s.clone().NewScope(s.Value)
 | 
				
			||||||
	scope.addForeignKey(field, dest, onDelete, onUpdate)
 | 
						scope.addForeignKey(field, dest, onDelete, onUpdate)
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
 | 
				
			||||||
func (s *DB) Association(column string) *Association {
 | 
					func (s *DB) Association(column string) *Association {
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
	scope := s.clone().NewScope(s.Value)
 | 
						scope := s.clone().NewScope(s.Value)
 | 
				
			||||||
@ -491,7 +578,7 @@ func (s *DB) Association(column string) *Association {
 | 
				
			|||||||
			if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
 | 
								if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
 | 
				
			||||||
				err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
 | 
									err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				return &Association{Scope: scope, Column: column, Field: field}
 | 
									return &Association{scope: scope, column: column, field: field}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
 | 
								err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
 | 
				
			||||||
@ -501,26 +588,30 @@ func (s *DB) Association(column string) *Association {
 | 
				
			|||||||
	return &Association{Error: err}
 | 
						return &Association{Error: err}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Preload preload associations with given conditions
 | 
				
			||||||
 | 
					//    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | 
				
			||||||
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
 | 
					func (s *DB) Preload(column string, conditions ...interface{}) *DB {
 | 
				
			||||||
	return s.clone().search.Preload(column, conditions...).db
 | 
						return s.clone().search.Preload(column, conditions...).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Set set value by name
 | 
					// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
 | 
				
			||||||
func (s *DB) Set(name string, value interface{}) *DB {
 | 
					func (s *DB) Set(name string, value interface{}) *DB {
 | 
				
			||||||
	return s.clone().InstantSet(name, value)
 | 
						return s.clone().InstantSet(name, value)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// InstantSet instant set setting, will affect current db
 | 
				
			||||||
func (s *DB) InstantSet(name string, value interface{}) *DB {
 | 
					func (s *DB) InstantSet(name string, value interface{}) *DB {
 | 
				
			||||||
	s.values[name] = value
 | 
						s.values[name] = value
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get get value by name
 | 
					// Get get setting by name
 | 
				
			||||||
func (s *DB) Get(name string) (value interface{}, ok bool) {
 | 
					func (s *DB) Get(name string) (value interface{}, ok bool) {
 | 
				
			||||||
	value, ok = s.values[name]
 | 
						value, ok = s.values[name]
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetJoinTableHandler set a model's join table handler for a relation
 | 
				
			||||||
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
 | 
					func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
 | 
				
			||||||
	scope := s.NewScope(source)
 | 
						scope := s.NewScope(source)
 | 
				
			||||||
	for _, field := range scope.GetModelStruct().StructFields {
 | 
						for _, field := range scope.GetModelStruct().StructFields {
 | 
				
			||||||
@ -530,7 +621,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
 | 
				
			|||||||
				destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
 | 
									destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
 | 
				
			||||||
				handler.Setup(field.Relationship, many2many, source, destination)
 | 
									handler.Setup(field.Relationship, many2many, source, destination)
 | 
				
			||||||
				field.Relationship.JoinTableHandler = handler
 | 
									field.Relationship.JoinTableHandler = handler
 | 
				
			||||||
				if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
 | 
									if table := handler.Table(s); scope.Dialect().HasTable(table) {
 | 
				
			||||||
					s.Table(table).AutoMigrate(handler)
 | 
										s.Table(table).AutoMigrate(handler)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -538,9 +629,10 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AddError add error to the db
 | 
				
			||||||
func (s *DB) AddError(err error) error {
 | 
					func (s *DB) AddError(err error) error {
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		if err != RecordNotFound {
 | 
							if err != ErrRecordNotFound {
 | 
				
			||||||
			if s.logMode == 0 {
 | 
								if s.logMode == 0 {
 | 
				
			||||||
				go s.print(fileWithLineNum(), err)
 | 
									go s.print(fileWithLineNum(), err)
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
@ -559,6 +651,7 @@ func (s *DB) AddError(err error) error {
 | 
				
			|||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GetErrors get happened errors from the db
 | 
				
			||||||
func (s *DB) GetErrors() (errors []error) {
 | 
					func (s *DB) GetErrors() (errors []error) {
 | 
				
			||||||
	if errs, ok := s.Error.(errorsInterface); ok {
 | 
						if errs, ok := s.Error.(errorsInterface); ok {
 | 
				
			||||||
		return errs.GetErrors()
 | 
							return errs.GetErrors()
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,7 @@ func (s *DB) clone() *DB {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if s.search == nil {
 | 
						if s.search == nil {
 | 
				
			||||||
		db.search = &search{}
 | 
							db.search = &search{limit: -1, offset: -1}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		db.search = s.search.clone()
 | 
							db.search = s.search.clone()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										96
									
								
								main_test.go
									
									
									
									
									
								
							
							
						
						
									
										96
									
								
								main_test.go
									
									
									
									
									
								
							@ -4,23 +4,23 @@ import (
 | 
				
			|||||||
	"database/sql"
 | 
						"database/sql"
 | 
				
			||||||
	"database/sql/driver"
 | 
						"database/sql/driver"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_ "github.com/denisenkom/go-mssqldb"
 | 
					 | 
				
			||||||
	testdb "github.com/erikstmartin/go-testdb"
 | 
					 | 
				
			||||||
	_ "github.com/go-sql-driver/mysql"
 | 
					 | 
				
			||||||
	"github.com/jinzhu/gorm"
 | 
					 | 
				
			||||||
	"github.com/jinzhu/now"
 | 
					 | 
				
			||||||
	_ "github.com/lib/pq"
 | 
					 | 
				
			||||||
	_ "github.com/mattn/go-sqlite3"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/erikstmartin/go-testdb"
 | 
				
			||||||
 | 
						"github.com/jinzhu/gorm"
 | 
				
			||||||
 | 
						_ "github.com/jinzhu/gorm/dialects/mssql"
 | 
				
			||||||
 | 
						_ "github.com/jinzhu/gorm/dialects/mysql"
 | 
				
			||||||
 | 
						"github.com/jinzhu/gorm/dialects/postgres"
 | 
				
			||||||
 | 
						_ "github.com/jinzhu/gorm/dialects/sqlite"
 | 
				
			||||||
 | 
						"github.com/jinzhu/now"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	DB                 gorm.DB
 | 
						DB                 *gorm.DB
 | 
				
			||||||
	t1, t2, t3, t4, t5 time.Time
 | 
						t1, t2, t3, t4, t5 time.Time
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -42,7 +42,7 @@ func init() {
 | 
				
			|||||||
	runMigration()
 | 
						runMigration()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func OpenTestConnection() (db gorm.DB, err error) {
 | 
					func OpenTestConnection() (db *gorm.DB, err error) {
 | 
				
			||||||
	switch os.Getenv("GORM_DIALECT") {
 | 
						switch os.Getenv("GORM_DIALECT") {
 | 
				
			||||||
	case "mysql":
 | 
						case "mysql":
 | 
				
			||||||
		// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
 | 
							// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
 | 
				
			||||||
@ -115,7 +115,7 @@ func TestSetTable(t *testing.T) {
 | 
				
			|||||||
	DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
 | 
						DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
 | 
						if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
 | 
				
			||||||
		t.Errorf("No errors should happen if set table for pluck", err.Error())
 | 
							t.Error("No errors should happen if set table for pluck", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var users []User
 | 
						var users []User
 | 
				
			||||||
@ -376,7 +376,7 @@ func TestRows(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
 | 
						rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("Not error should happen, but got")
 | 
							t.Errorf("Not error should happen, got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	count := 0
 | 
						count := 0
 | 
				
			||||||
@ -386,8 +386,39 @@ func TestRows(t *testing.T) {
 | 
				
			|||||||
		rows.Scan(&name, &age)
 | 
							rows.Scan(&name, &age)
 | 
				
			||||||
		count++
 | 
							count++
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if count != 2 {
 | 
						if count != 2 {
 | 
				
			||||||
		t.Errorf("Should found two records with name 3")
 | 
							t.Errorf("Should found two records")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestScanRows(t *testing.T) {
 | 
				
			||||||
 | 
						user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
 | 
				
			||||||
 | 
						user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
 | 
				
			||||||
 | 
						user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
 | 
				
			||||||
 | 
						DB.Save(&user1).Save(&user2).Save(&user3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Not error should happen, got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type Result struct {
 | 
				
			||||||
 | 
							Name string
 | 
				
			||||||
 | 
							Age  int
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var results []Result
 | 
				
			||||||
 | 
						for rows.Next() {
 | 
				
			||||||
 | 
							var result Result
 | 
				
			||||||
 | 
							if err := DB.ScanRows(rows, &result); err != nil {
 | 
				
			||||||
 | 
								t.Errorf("should get no error, but got %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							results = append(results, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
 | 
				
			||||||
 | 
							t.Errorf("Should find expected results")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -448,7 +479,7 @@ func TestRaw(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
 | 
						DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
 | 
				
			||||||
	if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
 | 
						if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
 | 
				
			||||||
		t.Error("Raw sql to update records")
 | 
							t.Error("Raw sql to update records")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -470,14 +501,33 @@ func TestGroup(t *testing.T) {
 | 
				
			|||||||
func TestJoins(t *testing.T) {
 | 
					func TestJoins(t *testing.T) {
 | 
				
			||||||
	var user = User{
 | 
						var user = User{
 | 
				
			||||||
		Name:       "joins",
 | 
							Name:       "joins",
 | 
				
			||||||
 | 
							CreditCard: CreditCard{Number: "411111111111"},
 | 
				
			||||||
		Emails:     []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
 | 
							Emails:     []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	DB.Save(&user)
 | 
						DB.Save(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var result User
 | 
						var users1 []User
 | 
				
			||||||
	DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result)
 | 
						DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
 | 
				
			||||||
	if result.Name != "joins" || result.Id != user.Id {
 | 
						if len(users1) != 2 {
 | 
				
			||||||
		t.Errorf("Should find all two emails with Join")
 | 
							t.Errorf("should find two users using left join")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var users2 []User
 | 
				
			||||||
 | 
						DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
 | 
				
			||||||
 | 
						if len(users2) != 1 {
 | 
				
			||||||
 | 
							t.Errorf("should find one users using left join with conditions")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var users3 []User
 | 
				
			||||||
 | 
						DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
 | 
				
			||||||
 | 
						if len(users3) != 1 {
 | 
				
			||||||
 | 
							t.Errorf("should find one users using multiple left join conditions")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var users4 []User
 | 
				
			||||||
 | 
						DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
 | 
				
			||||||
 | 
						if len(users4) != 0 {
 | 
				
			||||||
 | 
							t.Errorf("should find no user when searching with unexisting credit card")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -557,7 +607,7 @@ func TestTimeWithZone(t *testing.T) {
 | 
				
			|||||||
		DB.First(&findUser, "name = ?", name)
 | 
							DB.First(&findUser, "name = ?", name)
 | 
				
			||||||
		foundBirthday = findUser.Birthday.UTC().Format(format)
 | 
							foundBirthday = findUser.Birthday.UTC().Format(format)
 | 
				
			||||||
		if foundBirthday != expectedBirthday {
 | 
							if foundBirthday != expectedBirthday {
 | 
				
			||||||
			t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
 | 
								t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
 | 
							if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
 | 
				
			||||||
@ -573,7 +623,7 @@ func TestTimeWithZone(t *testing.T) {
 | 
				
			|||||||
func TestHstore(t *testing.T) {
 | 
					func TestHstore(t *testing.T) {
 | 
				
			||||||
	type Details struct {
 | 
						type Details struct {
 | 
				
			||||||
		Id   int64
 | 
							Id   int64
 | 
				
			||||||
		Bulk gorm.Hstore
 | 
							Bulk postgres.Hstore
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
 | 
						if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
 | 
				
			||||||
@ -659,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var user User
 | 
						var user User
 | 
				
			||||||
	if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
 | 
						if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
 | 
				
			||||||
		t.Errorf("Should have found existing record")
 | 
							t.Errorf("Should have found existing record")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	scope := DB.NewScope(&Email{})
 | 
						scope := DB.NewScope(&Email{})
 | 
				
			||||||
	if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
 | 
						if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
 | 
				
			||||||
		t.Errorf("Email should have index idx_email_email")
 | 
							t.Errorf("Email should have index idx_email_email")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Got error when tried to remove index: %+v", err)
 | 
							t.Errorf("Got error when tried to remove index: %+v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
 | 
						if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
 | 
				
			||||||
		t.Errorf("Email's index idx_email_email should be deleted")
 | 
							t.Errorf("Email's index idx_email_email should be deleted")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Got error when tried to create index: %+v", err)
 | 
							t.Errorf("Got error when tried to create index: %+v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
 | 
						if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
				
			||||||
		t.Errorf("Email should have index idx_email_email_and_user_id")
 | 
							t.Errorf("Email should have index idx_email_email_and_user_id")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Got error when tried to remove index: %+v", err)
 | 
							t.Errorf("Got error when tried to remove index: %+v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
 | 
						if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
				
			||||||
		t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
 | 
							t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Got error when tried to create index: %+v", err)
 | 
							t.Errorf("Got error when tried to create index: %+v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
 | 
						if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
				
			||||||
		t.Errorf("Email should have index idx_email_email_and_user_id")
 | 
							t.Errorf("Email should have index idx_email_email_and_user_id")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Got error when tried to remove index: %+v", err)
 | 
							t.Errorf("Got error when tried to remove index: %+v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
 | 
						if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
				
			||||||
		t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
 | 
							t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) {
 | 
				
			|||||||
	DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
 | 
						DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	scope := DB.NewScope(&BigEmail{})
 | 
						scope := DB.NewScope(&BigEmail{})
 | 
				
			||||||
	if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
 | 
						if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
 | 
				
			||||||
		t.Errorf("Failed to create index")
 | 
							t.Errorf("Failed to create index")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
 | 
						if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") {
 | 
				
			||||||
		t.Errorf("Failed to create index")
 | 
							t.Errorf("Failed to create index")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								model.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								model.go
									
									
									
									
									
								
							@ -2,6 +2,10 @@ package gorm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import "time"
 | 
					import "time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
 | 
				
			||||||
 | 
					//    type User struct {
 | 
				
			||||||
 | 
					//      gorm.Model
 | 
				
			||||||
 | 
					//    }
 | 
				
			||||||
type Model struct {
 | 
					type Model struct {
 | 
				
			||||||
	ID        uint `gorm:"primary_key"`
 | 
						ID        uint `gorm:"primary_key"`
 | 
				
			||||||
	CreatedAt time.Time
 | 
						CreatedAt time.Time
 | 
				
			||||||
 | 
				
			|||||||
@ -3,10 +3,8 @@ package gorm
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"database/sql"
 | 
						"database/sql"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"go/ast"
 | 
						"go/ast"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@ -14,6 +12,7 @@ import (
 | 
				
			|||||||
	"github.com/jinzhu/inflection"
 | 
						"github.com/jinzhu/inflection"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DefaultTableNameHandler default table name handler
 | 
				
			||||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
 | 
					var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
 | 
				
			||||||
	return defaultTableName
 | 
						return defaultTableName
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -41,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var modelStructsMap = newModelStructsMap()
 | 
					var modelStructsMap = newModelStructsMap()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ModelStruct model definition
 | 
				
			||||||
type ModelStruct struct {
 | 
					type ModelStruct struct {
 | 
				
			||||||
	PrimaryFields    []*StructField
 | 
						PrimaryFields    []*StructField
 | 
				
			||||||
	StructFields     []*StructField
 | 
						StructFields     []*StructField
 | 
				
			||||||
@ -48,10 +48,12 @@ type ModelStruct struct {
 | 
				
			|||||||
	defaultTableName string
 | 
						defaultTableName string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TableName get model's table name
 | 
				
			||||||
func (s *ModelStruct) TableName(db *DB) string {
 | 
					func (s *ModelStruct) TableName(db *DB) string {
 | 
				
			||||||
	return DefaultTableNameHandler(db, s.defaultTableName)
 | 
						return DefaultTableNameHandler(db, s.defaultTableName)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// StructField model field's struct definition
 | 
				
			||||||
type StructField struct {
 | 
					type StructField struct {
 | 
				
			||||||
	DBName          string
 | 
						DBName          string
 | 
				
			||||||
	Name            string
 | 
						Name            string
 | 
				
			||||||
@ -107,7 +109,7 @@ func getForeignField(column string, fields []*StructField) *StructField {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetModelStruct generate model struct & relationships based on struct and tag definition
 | 
					// GetModelStruct get value's model struct, relationships based on struct and tag definition
 | 
				
			||||||
func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
					func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
				
			||||||
	var modelStruct ModelStruct
 | 
						var modelStruct ModelStruct
 | 
				
			||||||
	// Scope value can't be nil
 | 
						// Scope value can't be nil
 | 
				
			||||||
@ -296,7 +298,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
				
			|||||||
										if len(associationForeignKeys) == 0 {
 | 
															if len(associationForeignKeys) == 0 {
 | 
				
			||||||
											for _, foreignKey := range foreignKeys {
 | 
																for _, foreignKey := range foreignKeys {
 | 
				
			||||||
												if strings.HasPrefix(foreignKey, associationType) {
 | 
																	if strings.HasPrefix(foreignKey, associationType) {
 | 
				
			||||||
													associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
 | 
																		associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
 | 
				
			||||||
 | 
																		if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
 | 
				
			||||||
 | 
																			associationForeignKeys = append(associationForeignKeys, associationForeignKey)
 | 
				
			||||||
 | 
																		}
 | 
				
			||||||
												}
 | 
																	}
 | 
				
			||||||
											}
 | 
																}
 | 
				
			||||||
											if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
 | 
																if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
 | 
				
			||||||
@ -389,7 +394,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
				
			|||||||
									if len(associationForeignKeys) == 0 {
 | 
														if len(associationForeignKeys) == 0 {
 | 
				
			||||||
										for _, foreignKey := range foreignKeys {
 | 
															for _, foreignKey := range foreignKeys {
 | 
				
			||||||
											if strings.HasPrefix(foreignKey, associationType) {
 | 
																if strings.HasPrefix(foreignKey, associationType) {
 | 
				
			||||||
												associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
 | 
																	associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
 | 
				
			||||||
 | 
																	if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
 | 
				
			||||||
 | 
																		associationForeignKeys = append(associationForeignKeys, associationForeignKey)
 | 
				
			||||||
 | 
																	}
 | 
				
			||||||
											}
 | 
																}
 | 
				
			||||||
										}
 | 
															}
 | 
				
			||||||
										if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
 | 
															if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
 | 
				
			||||||
@ -445,7 +453,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
				
			|||||||
									if len(associationForeignKeys) == 0 {
 | 
														if len(associationForeignKeys) == 0 {
 | 
				
			||||||
										for _, foreignKey := range foreignKeys {
 | 
															for _, foreignKey := range foreignKeys {
 | 
				
			||||||
											if strings.HasPrefix(foreignKey, field.Name) {
 | 
																if strings.HasPrefix(foreignKey, field.Name) {
 | 
				
			||||||
												associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name))
 | 
																	associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
 | 
				
			||||||
 | 
																	if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
 | 
				
			||||||
 | 
																		associationForeignKeys = append(associationForeignKeys, associationForeignKey)
 | 
				
			||||||
 | 
																	}
 | 
				
			||||||
											}
 | 
																}
 | 
				
			||||||
										}
 | 
															}
 | 
				
			||||||
										if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
 | 
															if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
 | 
				
			||||||
@ -508,63 +519,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
				
			|||||||
	return &modelStruct
 | 
						return &modelStruct
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GetStructFields get model's field structs
 | 
				
			||||||
func (scope *Scope) GetStructFields() (fields []*StructField) {
 | 
					func (scope *Scope) GetStructFields() (fields []*StructField) {
 | 
				
			||||||
	return scope.GetModelStruct().StructFields
 | 
						return scope.GetModelStruct().StructFields
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) generateSqlTag(field *StructField) string {
 | 
					 | 
				
			||||||
	var sqlType string
 | 
					 | 
				
			||||||
	structType := field.Struct.Type
 | 
					 | 
				
			||||||
	if structType.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
		structType = structType.Elem()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	reflectValue := reflect.Indirect(reflect.New(structType))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if value, ok := field.TagSettings["TYPE"]; ok {
 | 
					 | 
				
			||||||
		sqlType = value
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
 | 
					 | 
				
			||||||
	if value, ok := field.TagSettings["DEFAULT"]; ok {
 | 
					 | 
				
			||||||
		additionalType = additionalType + " DEFAULT " + value
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if field.IsScanner {
 | 
					 | 
				
			||||||
		var getScannerValue func(reflect.Value)
 | 
					 | 
				
			||||||
		getScannerValue = func(value reflect.Value) {
 | 
					 | 
				
			||||||
			reflectValue = value
 | 
					 | 
				
			||||||
			if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
 | 
					 | 
				
			||||||
				getScannerValue(reflectValue.Field(0))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		getScannerValue(reflectValue)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if sqlType == "" {
 | 
					 | 
				
			||||||
		var size = 255
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if value, ok := field.TagSettings["SIZE"]; ok {
 | 
					 | 
				
			||||||
			size, _ = strconv.Atoi(value)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
 | 
					 | 
				
			||||||
		if field.IsPrimaryKey {
 | 
					 | 
				
			||||||
			autoIncrease = true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if v == "FALSE" {
 | 
					 | 
				
			||||||
			autoIncrease = false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if strings.TrimSpace(additionalType) == "" {
 | 
					 | 
				
			||||||
		return sqlType
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func parseTagSetting(tags reflect.StructTag) map[string]string {
 | 
					func parseTagSetting(tags reflect.StructTag) map[string]string {
 | 
				
			||||||
	setting := map[string]string{}
 | 
						setting := map[string]string{}
 | 
				
			||||||
	for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
 | 
						for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										80
									
								
								mssql.go
									
									
									
									
									
								
							
							
						
						
									
										80
									
								
								mssql.go
									
									
									
									
									
								
							@ -1,80 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type mssql struct {
 | 
					 | 
				
			||||||
	commonDialect
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (mssql) HasTop() bool {
 | 
					 | 
				
			||||||
	return true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
 | 
					 | 
				
			||||||
	switch value.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Bool:
 | 
					 | 
				
			||||||
		return "bit"
 | 
					 | 
				
			||||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "int IDENTITY(1,1)"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "int"
 | 
					 | 
				
			||||||
	case reflect.Int64, reflect.Uint64:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "bigint IDENTITY(1,1)"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "bigint"
 | 
					 | 
				
			||||||
	case reflect.Float32, reflect.Float64:
 | 
					 | 
				
			||||||
		return "float"
 | 
					 | 
				
			||||||
	case reflect.String:
 | 
					 | 
				
			||||||
		if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
			return fmt.Sprintf("nvarchar(%d)", size)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "text"
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().(time.Time); ok {
 | 
					 | 
				
			||||||
			return "datetime2"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().([]byte); ok {
 | 
					 | 
				
			||||||
			if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
				return fmt.Sprintf("varchar(%d)", size)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return "text"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s mssql) HasTable(scope *Scope, tableName string) bool {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		count        int
 | 
					 | 
				
			||||||
		databaseName = s.CurrentDatabase(scope)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		count        int
 | 
					 | 
				
			||||||
		databaseName = s.CurrentDatabase(scope)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s mssql) CurrentDatabase(scope *Scope) (name string) {
 | 
					 | 
				
			||||||
	s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -21,7 +21,7 @@ type Tag struct {
 | 
				
			|||||||
	ID     uint   `gorm:"primary_key"`
 | 
						ID     uint   `gorm:"primary_key"`
 | 
				
			||||||
	Locale string `gorm:"primary_key"`
 | 
						Locale string `gorm:"primary_key"`
 | 
				
			||||||
	Value  string
 | 
						Value  string
 | 
				
			||||||
	Blogs  []*Blog `gorm:"many2many:"blogs_tags`
 | 
						Blogs  []*Blog `gorm:"many2many:blogs_tags"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func compareTags(tags []Tag, contents []string) bool {
 | 
					func compareTags(tags []Tag, contents []string) bool {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										70
									
								
								mysql.go
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								mysql.go
									
									
									
									
									
								
							@ -1,70 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type mysql struct {
 | 
					 | 
				
			||||||
	commonDialect
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
 | 
					 | 
				
			||||||
	switch value.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Bool:
 | 
					 | 
				
			||||||
		return "boolean"
 | 
					 | 
				
			||||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "int AUTO_INCREMENT"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "int"
 | 
					 | 
				
			||||||
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "int unsigned AUTO_INCREMENT"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "int unsigned"
 | 
					 | 
				
			||||||
	case reflect.Int64:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "bigint AUTO_INCREMENT"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "bigint"
 | 
					 | 
				
			||||||
	case reflect.Uint64:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "bigint unsigned AUTO_INCREMENT"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "bigint unsigned"
 | 
					 | 
				
			||||||
	case reflect.Float32, reflect.Float64:
 | 
					 | 
				
			||||||
		return "double"
 | 
					 | 
				
			||||||
	case reflect.String:
 | 
					 | 
				
			||||||
		if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
			return fmt.Sprintf("varchar(%d)", size)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "longtext"
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().(time.Time); ok {
 | 
					 | 
				
			||||||
			return "timestamp NULL"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().([]byte); ok {
 | 
					 | 
				
			||||||
			if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
				return fmt.Sprintf("varbinary(%d)", size)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return "longblob"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (mysql) Quote(key string) string {
 | 
					 | 
				
			||||||
	return fmt.Sprintf("`%s`", key)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (mysql) SelectFromDummyTable() string {
 | 
					 | 
				
			||||||
	return "FROM DUAL"
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
 | 
					 | 
				
			||||||
	s.RawScanString(scope, &name, "SELECT DATABASE()")
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -39,46 +39,46 @@ func TestPointerFields(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var nilPointerStruct = PointerStruct{}
 | 
						var nilPointerStruct = PointerStruct{}
 | 
				
			||||||
	if err := DB.Create(&nilPointerStruct).Error; err != nil {
 | 
						if err := DB.Create(&nilPointerStruct).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Failed to save nil pointer struct", err)
 | 
							t.Error("Failed to save nil pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var pointerStruct2 PointerStruct
 | 
						var pointerStruct2 PointerStruct
 | 
				
			||||||
	if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
 | 
						if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Failed to query saved nil pointer struct", err)
 | 
							t.Error("Failed to query saved nil pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var normalStruct2 NormalStruct
 | 
						var normalStruct2 NormalStruct
 | 
				
			||||||
	if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
 | 
						if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Failed to query saved nil pointer struct", err)
 | 
							t.Error("Failed to query saved nil pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var partialNilPointerStruct1 = PointerStruct{Num: &num}
 | 
						var partialNilPointerStruct1 = PointerStruct{Num: &num}
 | 
				
			||||||
	if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
 | 
						if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Failed to save partial nil pointer struct", err)
 | 
							t.Error("Failed to save partial nil pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var pointerStruct3 PointerStruct
 | 
						var pointerStruct3 PointerStruct
 | 
				
			||||||
	if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
 | 
						if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
 | 
				
			||||||
		t.Errorf("Failed to query saved partial nil pointer struct", err)
 | 
							t.Error("Failed to query saved partial nil pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var normalStruct3 NormalStruct
 | 
						var normalStruct3 NormalStruct
 | 
				
			||||||
	if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
 | 
						if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
 | 
				
			||||||
		t.Errorf("Failed to query saved partial pointer struct", err)
 | 
							t.Error("Failed to query saved partial pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var partialNilPointerStruct2 = PointerStruct{Name: &name}
 | 
						var partialNilPointerStruct2 = PointerStruct{Name: &name}
 | 
				
			||||||
	if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
 | 
						if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Failed to save partial nil pointer struct", err)
 | 
							t.Error("Failed to save partial nil pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var pointerStruct4 PointerStruct
 | 
						var pointerStruct4 PointerStruct
 | 
				
			||||||
	if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
 | 
						if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
 | 
				
			||||||
		t.Errorf("Failed to query saved partial nil pointer struct", err)
 | 
							t.Error("Failed to query saved partial nil pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var normalStruct4 NormalStruct
 | 
						var normalStruct4 NormalStruct
 | 
				
			||||||
	if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
 | 
						if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
 | 
				
			||||||
		t.Errorf("Failed to query saved partial pointer struct", err)
 | 
							t.Error("Failed to query saved partial pointer struct", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										154
									
								
								postgres.go
									
									
									
									
									
								
							
							
						
						
									
										154
									
								
								postgres.go
									
									
									
									
									
								
							@ -1,154 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"database/sql/driver"
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/lib/pq/hstore"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type postgres struct {
 | 
					 | 
				
			||||||
	commonDialect
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (postgres) BinVar(i int) string {
 | 
					 | 
				
			||||||
	return fmt.Sprintf("$%v", i)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (postgres) SupportLastInsertId() bool {
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
 | 
					 | 
				
			||||||
	switch value.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Bool:
 | 
					 | 
				
			||||||
		return "boolean"
 | 
					 | 
				
			||||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "serial"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "integer"
 | 
					 | 
				
			||||||
	case reflect.Int64, reflect.Uint64:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "bigserial"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "bigint"
 | 
					 | 
				
			||||||
	case reflect.Float32, reflect.Float64:
 | 
					 | 
				
			||||||
		return "numeric"
 | 
					 | 
				
			||||||
	case reflect.String:
 | 
					 | 
				
			||||||
		if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
			return fmt.Sprintf("varchar(%d)", size)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "text"
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().(time.Time); ok {
 | 
					 | 
				
			||||||
			return "timestamp with time zone"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case reflect.Map:
 | 
					 | 
				
			||||||
		if value.Type() == hstoreType {
 | 
					 | 
				
			||||||
			return "hstore"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		if isByteArrayOrSlice(value) {
 | 
					 | 
				
			||||||
			return "bytea"
 | 
					 | 
				
			||||||
		} else if isUUID(value) {
 | 
					 | 
				
			||||||
			return "uuid"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var byteType = reflect.TypeOf(uint8(0))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func isByteArrayOrSlice(value reflect.Value) bool {
 | 
					 | 
				
			||||||
	return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func isUUID(value reflect.Value) bool {
 | 
					 | 
				
			||||||
	if value.Kind() != reflect.Array || value.Type().Len() != 16 {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	typename := value.Type().Name()
 | 
					 | 
				
			||||||
	lower := strings.ToLower(typename)
 | 
					 | 
				
			||||||
	return "uuid" == lower || "guid" == lower
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s postgres) ReturningStr(tableName, key string) string {
 | 
					 | 
				
			||||||
	return fmt.Sprintf("RETURNING %v.%v", tableName, key)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s postgres) HasTable(scope *Scope, tableName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (postgres) RemoveIndex(scope *Scope, indexName string) {
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
 | 
					 | 
				
			||||||
	s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var hstoreType = reflect.TypeOf(Hstore{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Hstore map[string]*string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h Hstore) Value() (driver.Value, error) {
 | 
					 | 
				
			||||||
	hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
 | 
					 | 
				
			||||||
	if len(h) == 0 {
 | 
					 | 
				
			||||||
		return nil, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for key, value := range h {
 | 
					 | 
				
			||||||
		var s sql.NullString
 | 
					 | 
				
			||||||
		if value != nil {
 | 
					 | 
				
			||||||
			s.String = *value
 | 
					 | 
				
			||||||
			s.Valid = true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		hstore.Map[key] = s
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return hstore.Value()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *Hstore) Scan(value interface{}) error {
 | 
					 | 
				
			||||||
	hstore := hstore.Hstore{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := hstore.Scan(value); err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(hstore.Map) == 0 {
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	*h = Hstore{}
 | 
					 | 
				
			||||||
	for k := range hstore.Map {
 | 
					 | 
				
			||||||
		if hstore.Map[k].Valid {
 | 
					 | 
				
			||||||
			s := hstore.Map[k].String
 | 
					 | 
				
			||||||
			(*h)[k] = &s
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			(*h)[k] = nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										388
									
								
								preload.go
									
									
									
									
									
								
							
							
						
						
									
										388
									
								
								preload.go
									
									
									
									
									
								
							@ -1,388 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"database/sql/driver"
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
 | 
					 | 
				
			||||||
	// If value is a nil pointer, Indirect returns a zero Value!
 | 
					 | 
				
			||||||
	// Therefor we need to check for a zero value,
 | 
					 | 
				
			||||||
	// as FieldByName could panic
 | 
					 | 
				
			||||||
	if pointedValue := reflect.Indirect(value); pointedValue.IsValid() {
 | 
					 | 
				
			||||||
		for _, column := range columns {
 | 
					 | 
				
			||||||
			if pointedValue.FieldByName(column).IsValid() {
 | 
					 | 
				
			||||||
				result := pointedValue.FieldByName(column).Interface()
 | 
					 | 
				
			||||||
				if r, ok := result.(driver.Valuer); ok {
 | 
					 | 
				
			||||||
					result, _ = r.Value()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				results = append(results, result)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func equalAsString(a interface{}, b interface{}) bool {
 | 
					 | 
				
			||||||
	return toString(a) == toString(b)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func Preload(scope *Scope) {
 | 
					 | 
				
			||||||
	if scope.Search.preload == nil || scope.HasError() {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	preloadMap := map[string]bool{}
 | 
					 | 
				
			||||||
	fields := scope.Fields()
 | 
					 | 
				
			||||||
	for _, preload := range scope.Search.preload {
 | 
					 | 
				
			||||||
		schema, conditions := preload.schema, preload.conditions
 | 
					 | 
				
			||||||
		keys := strings.Split(schema, ".")
 | 
					 | 
				
			||||||
		currentScope := scope
 | 
					 | 
				
			||||||
		currentFields := fields
 | 
					 | 
				
			||||||
		originalConditions := conditions
 | 
					 | 
				
			||||||
		conditions = []interface{}{}
 | 
					 | 
				
			||||||
		for i, key := range keys {
 | 
					 | 
				
			||||||
			var found bool
 | 
					 | 
				
			||||||
			if preloadMap[strings.Join(keys[:i+1], ".")] {
 | 
					 | 
				
			||||||
				goto nextLoop
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if i == len(keys)-1 {
 | 
					 | 
				
			||||||
				conditions = originalConditions
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for _, field := range currentFields {
 | 
					 | 
				
			||||||
				if field.Name != key || field.Relationship == nil {
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				found = true
 | 
					 | 
				
			||||||
				switch field.Relationship.Kind {
 | 
					 | 
				
			||||||
				case "has_one":
 | 
					 | 
				
			||||||
					currentScope.handleHasOnePreload(field, conditions)
 | 
					 | 
				
			||||||
				case "has_many":
 | 
					 | 
				
			||||||
					currentScope.handleHasManyPreload(field, conditions)
 | 
					 | 
				
			||||||
				case "belongs_to":
 | 
					 | 
				
			||||||
					currentScope.handleBelongsToPreload(field, conditions)
 | 
					 | 
				
			||||||
				case "many_to_many":
 | 
					 | 
				
			||||||
					currentScope.handleManyToManyPreload(field, conditions)
 | 
					 | 
				
			||||||
				default:
 | 
					 | 
				
			||||||
					currentScope.Err(errors.New("not supported relation"))
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				break
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if !found {
 | 
					 | 
				
			||||||
				value := reflect.ValueOf(currentScope.Value)
 | 
					 | 
				
			||||||
				if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
 | 
					 | 
				
			||||||
					value = value.Index(0).Elem()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			preloadMap[strings.Join(keys[:i+1], ".")] = true
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		nextLoop:
 | 
					 | 
				
			||||||
			if i < len(keys)-1 {
 | 
					 | 
				
			||||||
				currentScope = currentScope.getColumnsAsScope(key)
 | 
					 | 
				
			||||||
				currentFields = currentScope.Fields()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func makeSlice(typ reflect.Type) interface{} {
 | 
					 | 
				
			||||||
	if typ.Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
		typ = typ.Elem()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	sliceType := reflect.SliceOf(typ)
 | 
					 | 
				
			||||||
	slice := reflect.New(sliceType)
 | 
					 | 
				
			||||||
	slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
 | 
					 | 
				
			||||||
	return slice.Interface()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
 | 
					 | 
				
			||||||
	relation := field.Relationship
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
 | 
					 | 
				
			||||||
	if len(primaryKeys) == 0 {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results := makeSlice(field.Struct.Type)
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
 | 
					 | 
				
			||||||
	resultValues := reflect.Indirect(reflect.ValueOf(results))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i := 0; i < resultValues.Len(); i++ {
 | 
					 | 
				
			||||||
		result := resultValues.Index(i)
 | 
					 | 
				
			||||||
		if scope.IndirectValue().Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
			value := getRealValue(result, relation.ForeignFieldNames)
 | 
					 | 
				
			||||||
			objects := scope.IndirectValue()
 | 
					 | 
				
			||||||
			for j := 0; j < objects.Len(); j++ {
 | 
					 | 
				
			||||||
				if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
 | 
					 | 
				
			||||||
					reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
 | 
					 | 
				
			||||||
					break
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			if err := scope.SetColumn(field, result); err != nil {
 | 
					 | 
				
			||||||
				scope.Err(err)
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
 | 
					 | 
				
			||||||
	relation := field.Relationship
 | 
					 | 
				
			||||||
	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
 | 
					 | 
				
			||||||
	if len(primaryKeys) == 0 {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results := makeSlice(field.Struct.Type)
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
 | 
					 | 
				
			||||||
	resultValues := reflect.Indirect(reflect.ValueOf(results))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if scope.IndirectValue().Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
		preloadMap := make(map[string][]reflect.Value)
 | 
					 | 
				
			||||||
		for i := 0; i < resultValues.Len(); i++ {
 | 
					 | 
				
			||||||
			result := resultValues.Index(i)
 | 
					 | 
				
			||||||
			value := getRealValue(result, relation.ForeignFieldNames)
 | 
					 | 
				
			||||||
			preloadMap[toString(value)] = append(preloadMap[toString(value)], result)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		objects := scope.IndirectValue()
 | 
					 | 
				
			||||||
		for j := 0; j < objects.Len(); j++ {
 | 
					 | 
				
			||||||
			object := reflect.Indirect(objects.Index(j))
 | 
					 | 
				
			||||||
			objectRealValue := getRealValue(object, relation.AssociationForeignFieldNames)
 | 
					 | 
				
			||||||
			objectStringValue := toString(objectRealValue)
 | 
					 | 
				
			||||||
			if results, ok := preloadMap[objectStringValue]; ok {
 | 
					 | 
				
			||||||
				if object.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
					object = object.Elem()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				f := object.FieldByName(field.Name)
 | 
					 | 
				
			||||||
				f.Set(reflect.Append(f, results...))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		scope.SetColumn(field, resultValues)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
 | 
					 | 
				
			||||||
	relation := field.Relationship
 | 
					 | 
				
			||||||
	primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
 | 
					 | 
				
			||||||
	if len(primaryKeys) == 0 {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	results := makeSlice(field.Struct.Type)
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
 | 
					 | 
				
			||||||
	resultValues := reflect.Indirect(reflect.ValueOf(results))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i := 0; i < resultValues.Len(); i++ {
 | 
					 | 
				
			||||||
		result := resultValues.Index(i)
 | 
					 | 
				
			||||||
		if scope.IndirectValue().Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
			value := getRealValue(result, relation.AssociationForeignFieldNames)
 | 
					 | 
				
			||||||
			objects := scope.IndirectValue()
 | 
					 | 
				
			||||||
			for j := 0; j < objects.Len(); j++ {
 | 
					 | 
				
			||||||
				object := reflect.Indirect(objects.Index(j))
 | 
					 | 
				
			||||||
				if object.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
					object = reflect.Indirect(objects.Index(j).Elem())
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
 | 
					 | 
				
			||||||
					object.FieldByName(field.Name).Set(result)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			scope.SetColumn(field, result)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
 | 
					 | 
				
			||||||
	relation := field.Relationship
 | 
					 | 
				
			||||||
	joinTableHandler := relation.JoinTableHandler
 | 
					 | 
				
			||||||
	destType := field.StructField.Struct.Type.Elem()
 | 
					 | 
				
			||||||
	var isPtr bool
 | 
					 | 
				
			||||||
	if destType.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
		isPtr = true
 | 
					 | 
				
			||||||
		destType = destType.Elem()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var sourceKeys []string
 | 
					 | 
				
			||||||
	var linkHash = make(map[string][]reflect.Value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, key := range joinTableHandler.SourceForeignKeys() {
 | 
					 | 
				
			||||||
		sourceKeys = append(sourceKeys, key.DBName)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(conditions) > 0 {
 | 
					 | 
				
			||||||
		preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	rows, err := preloadJoinDB.Rows()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if scope.Err(err) != nil {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	defer rows.Close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	columns, _ := rows.Columns()
 | 
					 | 
				
			||||||
	for rows.Next() {
 | 
					 | 
				
			||||||
		elem := reflect.New(destType).Elem()
 | 
					 | 
				
			||||||
		var values = make([]interface{}, len(columns))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		fields := scope.New(elem.Addr().Interface()).Fields()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		var foundFields = map[string]bool{}
 | 
					 | 
				
			||||||
		for index, column := range columns {
 | 
					 | 
				
			||||||
			if field, ok := fields[column]; ok && !foundFields[column] {
 | 
					 | 
				
			||||||
				if field.Field.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
					values[index] = field.Field.Addr().Interface()
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				foundFields[column] = true
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				var i interface{}
 | 
					 | 
				
			||||||
				values[index] = &i
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		scope.Err(rows.Scan(values...))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		var sourceKey []interface{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		var scannedFields = map[string]bool{}
 | 
					 | 
				
			||||||
		for index, column := range columns {
 | 
					 | 
				
			||||||
			value := values[index]
 | 
					 | 
				
			||||||
			if field, ok := fields[column]; ok && !scannedFields[column] {
 | 
					 | 
				
			||||||
				if field.Field.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
					field.Field.Set(reflect.ValueOf(value).Elem())
 | 
					 | 
				
			||||||
				} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
 | 
					 | 
				
			||||||
					field.Field.Set(v)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				scannedFields[column] = true
 | 
					 | 
				
			||||||
			} else if strInSlice(column, sourceKeys) {
 | 
					 | 
				
			||||||
				sourceKey = append(sourceKey, *(value.(*interface{})))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if len(sourceKey) != 0 {
 | 
					 | 
				
			||||||
			if isPtr {
 | 
					 | 
				
			||||||
				linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr())
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var foreignFieldNames []string
 | 
					 | 
				
			||||||
	for _, dbName := range relation.ForeignFieldNames {
 | 
					 | 
				
			||||||
		if field, ok := scope.FieldByName(dbName); ok {
 | 
					 | 
				
			||||||
			foreignFieldNames = append(foreignFieldNames, field.Name)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if scope.IndirectValue().Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
		objects := scope.IndirectValue()
 | 
					 | 
				
			||||||
		for j := 0; j < objects.Len(); j++ {
 | 
					 | 
				
			||||||
			object := reflect.Indirect(objects.Index(j))
 | 
					 | 
				
			||||||
			if object.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
				object = object.Elem()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			source := getRealValue(object, foreignFieldNames)
 | 
					 | 
				
			||||||
			field := object.FieldByName(field.Name)
 | 
					 | 
				
			||||||
			for _, link := range linkHash[toString(source)] {
 | 
					 | 
				
			||||||
				field.Set(reflect.Append(field, link))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		if object := scope.IndirectValue(); object.IsValid() {
 | 
					 | 
				
			||||||
			source := getRealValue(object, foreignFieldNames)
 | 
					 | 
				
			||||||
			field := object.FieldByName(field.Name)
 | 
					 | 
				
			||||||
			for _, link := range linkHash[toString(source)] {
 | 
					 | 
				
			||||||
				field.Set(reflect.Append(field, link))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
 | 
					 | 
				
			||||||
	values := scope.IndirectValue()
 | 
					 | 
				
			||||||
	switch values.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Slice:
 | 
					 | 
				
			||||||
		for i := 0; i < values.Len(); i++ {
 | 
					 | 
				
			||||||
			var result []interface{}
 | 
					 | 
				
			||||||
			for _, column := range columns {
 | 
					 | 
				
			||||||
				value := reflect.Indirect(values.Index(i))
 | 
					 | 
				
			||||||
				if value.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
					value = reflect.Indirect(values.Index(i).Elem())
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				result = append(result, value.FieldByName(column).Interface())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			results = append(results, result)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		var result []interface{}
 | 
					 | 
				
			||||||
		for _, column := range columns {
 | 
					 | 
				
			||||||
			result = append(result, values.FieldByName(column).Interface())
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return [][]interface{}{result}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) getColumnsAsScope(column string) *Scope {
 | 
					 | 
				
			||||||
	values := scope.IndirectValue()
 | 
					 | 
				
			||||||
	switch values.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Slice:
 | 
					 | 
				
			||||||
		modelType := values.Type().Elem()
 | 
					 | 
				
			||||||
		if modelType.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
			modelType = modelType.Elem()
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		fieldStruct, _ := modelType.FieldByName(column)
 | 
					 | 
				
			||||||
		var columns reflect.Value
 | 
					 | 
				
			||||||
		if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
			columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		for i := 0; i < values.Len(); i++ {
 | 
					 | 
				
			||||||
			column := reflect.Indirect(values.Index(i)).FieldByName(column)
 | 
					 | 
				
			||||||
			if column.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
				column = column.Elem()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if column.Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
				for i := 0; i < column.Len(); i++ {
 | 
					 | 
				
			||||||
					elem := column.Index(i)
 | 
					 | 
				
			||||||
					if elem.CanAddr() {
 | 
					 | 
				
			||||||
						columns = reflect.Append(columns, elem.Addr())
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				if column.CanAddr() {
 | 
					 | 
				
			||||||
					columns = reflect.Append(columns, column.Addr())
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return scope.New(columns.Interface())
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		field := values.FieldByName(column)
 | 
					 | 
				
			||||||
		if !field.CanAddr() {
 | 
					 | 
				
			||||||
			return nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return scope.New(field.Addr().Interface())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										251
									
								
								preload_test.go
									
									
									
									
									
								
							
							
						
						
									
										251
									
								
								preload_test.go
									
									
									
									
									
								
							@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
							t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
 | 
						if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
 | 
				
			||||||
		t.Error(err)
 | 
							t.Error(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -818,90 +818,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestManyToManyPreloadForPointer(t *testing.T) {
 | 
					 | 
				
			||||||
	type (
 | 
					 | 
				
			||||||
		Level1 struct {
 | 
					 | 
				
			||||||
			ID    uint
 | 
					 | 
				
			||||||
			Value string
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		Level2 struct {
 | 
					 | 
				
			||||||
			ID      uint
 | 
					 | 
				
			||||||
			Value   string
 | 
					 | 
				
			||||||
			Level1s []*Level1 `gorm:"many2many:levels;"`
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	DB.DropTableIfExists(&Level2{})
 | 
					 | 
				
			||||||
	DB.DropTableIfExists(&Level1{})
 | 
					 | 
				
			||||||
	DB.DropTableIfExists("levels")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	want := Level2{Value: "Bob", Level1s: []*Level1{
 | 
					 | 
				
			||||||
		{Value: "ru"},
 | 
					 | 
				
			||||||
		{Value: "en"},
 | 
					 | 
				
			||||||
	}}
 | 
					 | 
				
			||||||
	if err := DB.Save(&want).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	want2 := Level2{Value: "Tom", Level1s: []*Level1{
 | 
					 | 
				
			||||||
		{Value: "zh"},
 | 
					 | 
				
			||||||
		{Value: "de"},
 | 
					 | 
				
			||||||
	}}
 | 
					 | 
				
			||||||
	if err := DB.Save(&want2).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var got Level2
 | 
					 | 
				
			||||||
	if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(got, want) {
 | 
					 | 
				
			||||||
		t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var got2 Level2
 | 
					 | 
				
			||||||
	if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(got2, want2) {
 | 
					 | 
				
			||||||
		t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var got3 []Level2
 | 
					 | 
				
			||||||
	if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(got3, []Level2{got, got2}) {
 | 
					 | 
				
			||||||
		t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var got4 []Level2
 | 
					 | 
				
			||||||
	if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
 | 
					 | 
				
			||||||
		t.Error(err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var got5 Level2
 | 
					 | 
				
			||||||
	DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var ruLevel1 Level1
 | 
					 | 
				
			||||||
	var zhLevel1 Level1
 | 
					 | 
				
			||||||
	DB.First(&ruLevel1, "value = ?", "ru")
 | 
					 | 
				
			||||||
	DB.First(&zhLevel1, "value = ?", "zh")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	got.Level1s = []*Level1{&ruLevel1}
 | 
					 | 
				
			||||||
	got2.Level1s = []*Level1{&zhLevel1}
 | 
					 | 
				
			||||||
	if !reflect.DeepEqual(got4, []Level2{got, got2}) {
 | 
					 | 
				
			||||||
		t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestManyToManyPreloadForNestedPointer(t *testing.T) {
 | 
					func TestManyToManyPreloadForNestedPointer(t *testing.T) {
 | 
				
			||||||
	type (
 | 
						type (
 | 
				
			||||||
		Level1 struct {
 | 
							Level1 struct {
 | 
				
			||||||
@ -1065,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
							t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
 | 
						if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
 | 
				
			||||||
		t.Error(err)
 | 
							t.Error(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -1122,12 +1038,87 @@ func TestNestedManyToManyPreload2(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
							t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
 | 
						if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
 | 
				
			||||||
		t.Error(err)
 | 
							t.Error(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNestedManyToManyPreload3(t *testing.T) {
 | 
					func TestNestedManyToManyPreload3(t *testing.T) {
 | 
				
			||||||
 | 
						type (
 | 
				
			||||||
 | 
							Level1 struct {
 | 
				
			||||||
 | 
								ID    uint
 | 
				
			||||||
 | 
								Value string
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							Level2 struct {
 | 
				
			||||||
 | 
								ID      uint
 | 
				
			||||||
 | 
								Value   string
 | 
				
			||||||
 | 
								Level1s []*Level1 `gorm:"many2many:level1_level2;"`
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							Level3 struct {
 | 
				
			||||||
 | 
								ID       uint
 | 
				
			||||||
 | 
								Value    string
 | 
				
			||||||
 | 
								Level2ID sql.NullInt64
 | 
				
			||||||
 | 
								Level2   *Level2
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB.DropTableIfExists(&Level1{})
 | 
				
			||||||
 | 
						DB.DropTableIfExists(&Level2{})
 | 
				
			||||||
 | 
						DB.DropTableIfExists(&Level3{})
 | 
				
			||||||
 | 
						DB.DropTableIfExists("level1_level2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						level1Zh := &Level1{Value: "zh"}
 | 
				
			||||||
 | 
						level1Ru := &Level1{Value: "ru"}
 | 
				
			||||||
 | 
						level1En := &Level1{Value: "en"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						level21 := &Level2{
 | 
				
			||||||
 | 
							Value:   "Level2-1",
 | 
				
			||||||
 | 
							Level1s: []*Level1{level1Zh, level1Ru},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						level22 := &Level2{
 | 
				
			||||||
 | 
							Value:   "Level2-2",
 | 
				
			||||||
 | 
							Level1s: []*Level1{level1Zh, level1En},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wants := []*Level3{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Value:  "Level3-1",
 | 
				
			||||||
 | 
								Level2: level21,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Value:  "Level3-2",
 | 
				
			||||||
 | 
								Level2: level22,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Value:  "Level3-3",
 | 
				
			||||||
 | 
								Level2: level21,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, want := range wants {
 | 
				
			||||||
 | 
							if err := DB.Save(&want).Error; err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var gots []*Level3
 | 
				
			||||||
 | 
						if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
 | 
				
			||||||
 | 
							return db.Order("level1.id ASC")
 | 
				
			||||||
 | 
						}).Find(&gots).Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !reflect.DeepEqual(gots, wants) {
 | 
				
			||||||
 | 
							t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestNestedManyToManyPreload4(t *testing.T) {
 | 
				
			||||||
	type (
 | 
						type (
 | 
				
			||||||
		Level4 struct {
 | 
							Level4 struct {
 | 
				
			||||||
			ID       uint
 | 
								ID       uint
 | 
				
			||||||
@ -1185,6 +1176,90 @@ func TestNestedManyToManyPreload3(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestManyToManyPreloadForPointer(t *testing.T) {
 | 
				
			||||||
 | 
						type (
 | 
				
			||||||
 | 
							Level1 struct {
 | 
				
			||||||
 | 
								ID    uint
 | 
				
			||||||
 | 
								Value string
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							Level2 struct {
 | 
				
			||||||
 | 
								ID      uint
 | 
				
			||||||
 | 
								Value   string
 | 
				
			||||||
 | 
								Level1s []*Level1 `gorm:"many2many:levels;"`
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB.DropTableIfExists(&Level2{})
 | 
				
			||||||
 | 
						DB.DropTableIfExists(&Level1{})
 | 
				
			||||||
 | 
						DB.DropTableIfExists("levels")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						want := Level2{Value: "Bob", Level1s: []*Level1{
 | 
				
			||||||
 | 
							{Value: "ru"},
 | 
				
			||||||
 | 
							{Value: "en"},
 | 
				
			||||||
 | 
						}}
 | 
				
			||||||
 | 
						if err := DB.Save(&want).Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						want2 := Level2{Value: "Tom", Level1s: []*Level1{
 | 
				
			||||||
 | 
							{Value: "zh"},
 | 
				
			||||||
 | 
							{Value: "de"},
 | 
				
			||||||
 | 
						}}
 | 
				
			||||||
 | 
						if err := DB.Save(&want2).Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var got Level2
 | 
				
			||||||
 | 
						if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !reflect.DeepEqual(got, want) {
 | 
				
			||||||
 | 
							t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var got2 Level2
 | 
				
			||||||
 | 
						if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !reflect.DeepEqual(got2, want2) {
 | 
				
			||||||
 | 
							t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var got3 []Level2
 | 
				
			||||||
 | 
						if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !reflect.DeepEqual(got3, []Level2{got, got2}) {
 | 
				
			||||||
 | 
							t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var got4 []Level2
 | 
				
			||||||
 | 
						if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var got5 Level2
 | 
				
			||||||
 | 
						DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var ruLevel1 Level1
 | 
				
			||||||
 | 
						var zhLevel1 Level1
 | 
				
			||||||
 | 
						DB.First(&ruLevel1, "value = ?", "ru")
 | 
				
			||||||
 | 
						DB.First(&zhLevel1, "value = ?", "zh")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						got.Level1s = []*Level1{&ruLevel1}
 | 
				
			||||||
 | 
						got2.Level1s = []*Level1{&zhLevel1}
 | 
				
			||||||
 | 
						if !reflect.DeepEqual(got4, []Level2{got, got2}) {
 | 
				
			||||||
 | 
							t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNilPointerSlice(t *testing.T) {
 | 
					func TestNilPointerSlice(t *testing.T) {
 | 
				
			||||||
	type (
 | 
						type (
 | 
				
			||||||
		Level3 struct {
 | 
							Level3 struct {
 | 
				
			||||||
@ -1234,7 +1309,7 @@ func TestNilPointerSlice(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(got) != 2 {
 | 
						if len(got) != 2 {
 | 
				
			||||||
		t.Error("got %v items, expected 2", len(got))
 | 
							t.Errorf("got %v items, expected 2", len(got))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
 | 
						if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
 | 
				
			||||||
 | 
				
			|||||||
@ -629,14 +629,3 @@ func TestSelectWithArrayInput(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Should have selected both age and name")
 | 
							t.Errorf("Should have selected both age and name")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestCurrentDatabase(t *testing.T) {
 | 
					 | 
				
			||||||
	databaseName := DB.CurrentDatabase()
 | 
					 | 
				
			||||||
	if err := DB.Error; err != nil {
 | 
					 | 
				
			||||||
		t.Errorf("Problem getting current db name: %s", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if databaseName == "" {
 | 
					 | 
				
			||||||
		t.Errorf("Current db name returned empty; this should never happen!")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	t.Logf("Got current db name: %v", databaseName)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										312
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										312
									
								
								scope.go
									
									
									
									
									
								
							@ -1,48 +1,32 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Scope contain current operation's information when you perform any operation on the database
 | 
				
			||||||
type Scope struct {
 | 
					type Scope struct {
 | 
				
			||||||
	Search          *search
 | 
						Search          *search
 | 
				
			||||||
	Value           interface{}
 | 
						Value           interface{}
 | 
				
			||||||
	Sql             string
 | 
						SQL             string
 | 
				
			||||||
	SqlVars         []interface{}
 | 
						SQLVars         []interface{}
 | 
				
			||||||
	db              *DB
 | 
						db              *DB
 | 
				
			||||||
	indirectValue   *reflect.Value
 | 
						instanceID      string
 | 
				
			||||||
	instanceId      string
 | 
					 | 
				
			||||||
	primaryKeyField *Field
 | 
						primaryKeyField *Field
 | 
				
			||||||
	skipLeft        bool
 | 
						skipLeft        bool
 | 
				
			||||||
	fields          map[string]*Field
 | 
						fields          map[string]*Field
 | 
				
			||||||
	selectAttrs     *[]string
 | 
						selectAttrs     *[]string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IndirectValue return scope's reflect value's indirect value
 | 
				
			||||||
func (scope *Scope) IndirectValue() reflect.Value {
 | 
					func (scope *Scope) IndirectValue() reflect.Value {
 | 
				
			||||||
	if scope.indirectValue == nil {
 | 
						return indirect(reflect.ValueOf(scope.Value))
 | 
				
			||||||
		value := reflect.Indirect(reflect.ValueOf(scope.Value))
 | 
					 | 
				
			||||||
		if value.Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
			value = value.Elem()
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		scope.indirectValue = &value
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return *scope.indirectValue
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) NeedPtr() *Scope {
 | 
					 | 
				
			||||||
	reflectKind := reflect.ValueOf(scope.Value).Kind()
 | 
					 | 
				
			||||||
	if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
 | 
					 | 
				
			||||||
		err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
 | 
					 | 
				
			||||||
		scope.Err(err)
 | 
					 | 
				
			||||||
		fmt.Printf(err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return scope
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// New create a new Scope without search information
 | 
					// New create a new Scope without search information
 | 
				
			||||||
@ -61,12 +45,13 @@ func (scope *Scope) NewDB() *DB {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DB return scope's DB connection
 | 
				
			||||||
func (scope *Scope) DB() *DB {
 | 
					func (scope *Scope) DB() *DB {
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SqlDB return *sql.DB
 | 
					// SQLDB return *sql.DB
 | 
				
			||||||
func (scope *Scope) SqlDB() sqlCommon {
 | 
					func (scope *Scope) SQLDB() sqlCommon {
 | 
				
			||||||
	return scope.db.db
 | 
						return scope.db.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -75,7 +60,7 @@ func (scope *Scope) SkipLeft() {
 | 
				
			|||||||
	scope.skipLeft = true
 | 
						scope.skipLeft = true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Quote used to quote database column name according to database dialect
 | 
					// Quote used to quote string to escape them for database
 | 
				
			||||||
func (scope *Scope) Quote(str string) string {
 | 
					func (scope *Scope) Quote(str string) string {
 | 
				
			||||||
	if strings.Index(str, ".") != -1 {
 | 
						if strings.Index(str, ".") != -1 {
 | 
				
			||||||
		newStrs := []string{}
 | 
							newStrs := []string{}
 | 
				
			||||||
@ -83,12 +68,12 @@ func (scope *Scope) Quote(str string) string {
 | 
				
			|||||||
			newStrs = append(newStrs, scope.Dialect().Quote(str))
 | 
								newStrs = append(newStrs, scope.Dialect().Quote(str))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return strings.Join(newStrs, ".")
 | 
							return strings.Join(newStrs, ".")
 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		return scope.Dialect().Quote(str)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) QuoteIfPossible(str string) string {
 | 
						return scope.Dialect().Quote(str)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) quoteIfPossible(str string) string {
 | 
				
			||||||
	if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
 | 
						if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
 | 
				
			||||||
		return scope.Quote(str)
 | 
							return scope.Quote(str)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -100,7 +85,7 @@ func (scope *Scope) Dialect() Dialect {
 | 
				
			|||||||
	return scope.db.parent.dialect
 | 
						return scope.db.parent.dialect
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Err write error
 | 
					// Err add error to Scope
 | 
				
			||||||
func (scope *Scope) Err(err error) error {
 | 
					func (scope *Scope) Err(err error) error {
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		scope.db.AddError(err)
 | 
							scope.db.AddError(err)
 | 
				
			||||||
@ -118,27 +103,30 @@ func (scope *Scope) HasError() bool {
 | 
				
			|||||||
	return scope.db.Error != nil
 | 
						return scope.db.Error != nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) PrimaryFields() []*Field {
 | 
					// PrimaryFields return scope's primary fields
 | 
				
			||||||
	var fields = []*Field{}
 | 
					func (scope *Scope) PrimaryFields() (fields []*Field) {
 | 
				
			||||||
	for _, field := range scope.GetModelStruct().PrimaryFields {
 | 
						for _, field := range scope.Fields() {
 | 
				
			||||||
		fields = append(fields, scope.Fields()[field.DBName])
 | 
							if field.IsPrimaryKey {
 | 
				
			||||||
 | 
								fields = append(fields, field)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return fields
 | 
						return fields
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
 | 
				
			||||||
func (scope *Scope) PrimaryField() *Field {
 | 
					func (scope *Scope) PrimaryField() *Field {
 | 
				
			||||||
	if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
 | 
						if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
 | 
				
			||||||
		if len(primaryFields) > 1 {
 | 
							if len(primaryFields) > 1 {
 | 
				
			||||||
			if field, ok := scope.Fields()["id"]; ok {
 | 
								if field, ok := scope.FieldByName("id"); ok {
 | 
				
			||||||
				return field
 | 
									return field
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return scope.Fields()[primaryFields[0].DBName]
 | 
							return scope.PrimaryFields()[0]
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// PrimaryKey get the primary key's column name
 | 
					// PrimaryKey get main primary field's db name
 | 
				
			||||||
func (scope *Scope) PrimaryKey() string {
 | 
					func (scope *Scope) PrimaryKey() string {
 | 
				
			||||||
	if field := scope.PrimaryField(); field != nil {
 | 
						if field := scope.PrimaryField(); field != nil {
 | 
				
			||||||
		return field.DBName
 | 
							return field.DBName
 | 
				
			||||||
@ -146,7 +134,7 @@ func (scope *Scope) PrimaryKey() string {
 | 
				
			|||||||
	return ""
 | 
						return ""
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// PrimaryKeyZero check the primary key is blank or not
 | 
					// PrimaryKeyZero check main primary field's value is blank or not
 | 
				
			||||||
func (scope *Scope) PrimaryKeyZero() bool {
 | 
					func (scope *Scope) PrimaryKeyZero() bool {
 | 
				
			||||||
	field := scope.PrimaryField()
 | 
						field := scope.PrimaryField()
 | 
				
			||||||
	return field == nil || field.IsBlank
 | 
						return field == nil || field.IsBlank
 | 
				
			||||||
@ -170,80 +158,85 @@ func (scope *Scope) HasColumn(column string) bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetColumn to set the column's value
 | 
					// SetColumn to set the column's value, column could be field or field's name/dbname
 | 
				
			||||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
 | 
					func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
 | 
				
			||||||
 | 
						var updateAttrs = map[string]interface{}{}
 | 
				
			||||||
 | 
						if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
 | 
				
			||||||
 | 
							updateAttrs = attrs.(map[string]interface{})
 | 
				
			||||||
 | 
							defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if field, ok := column.(*Field); ok {
 | 
						if field, ok := column.(*Field); ok {
 | 
				
			||||||
 | 
							updateAttrs[field.DBName] = value
 | 
				
			||||||
		return field.Set(value)
 | 
							return field.Set(value)
 | 
				
			||||||
	} else if name, ok := column.(string); ok {
 | 
						} else if name, ok := column.(string); ok {
 | 
				
			||||||
 | 
							var (
 | 
				
			||||||
		if field, ok := scope.Fields()[name]; ok {
 | 
								dbName           = ToDBName(name)
 | 
				
			||||||
 | 
								mostMatchedField *Field
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
							for _, field := range scope.Fields() {
 | 
				
			||||||
 | 
								if field.DBName == value {
 | 
				
			||||||
 | 
									updateAttrs[field.DBName] = value
 | 
				
			||||||
				return field.Set(value)
 | 
									return field.Set(value)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
 | 
				
			||||||
		dbName := ToDBName(name)
 | 
									mostMatchedField = field
 | 
				
			||||||
		if field, ok := scope.Fields()[dbName]; ok {
 | 
								}
 | 
				
			||||||
			return field.Set(value)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if field, ok := scope.FieldByName(name); ok {
 | 
							if mostMatchedField != nil {
 | 
				
			||||||
			return field.Set(value)
 | 
								updateAttrs[mostMatchedField.DBName] = value
 | 
				
			||||||
 | 
								return mostMatchedField.Set(value)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return errors.New("could not convert column to field")
 | 
						return errors.New("could not convert column to field")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) CallMethod(name string, checkError bool) {
 | 
					func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
 | 
				
			||||||
	if scope.Value == nil || (checkError && scope.HasError()) {
 | 
						if reflectValue.CanAddr() {
 | 
				
			||||||
 | 
							reflectValue = reflectValue.Addr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
 | 
				
			||||||
 | 
							switch method := methodValue.Interface().(type) {
 | 
				
			||||||
 | 
							case func():
 | 
				
			||||||
 | 
								method()
 | 
				
			||||||
 | 
							case func(*Scope):
 | 
				
			||||||
 | 
								method(scope)
 | 
				
			||||||
 | 
							case func(*DB):
 | 
				
			||||||
 | 
								newDB := scope.NewDB()
 | 
				
			||||||
 | 
								method(newDB)
 | 
				
			||||||
 | 
								scope.Err(newDB.Error)
 | 
				
			||||||
 | 
							case func() error:
 | 
				
			||||||
 | 
								scope.Err(method())
 | 
				
			||||||
 | 
							case func(*Scope) error:
 | 
				
			||||||
 | 
								scope.Err(method(scope))
 | 
				
			||||||
 | 
							case func(*DB) error:
 | 
				
			||||||
 | 
								newDB := scope.NewDB()
 | 
				
			||||||
 | 
								scope.Err(method(newDB))
 | 
				
			||||||
 | 
								scope.Err(newDB.Error)
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								scope.Err(fmt.Errorf("unsupported function %v", methodName))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
 | 
				
			||||||
 | 
					func (scope *Scope) CallMethod(methodName string) {
 | 
				
			||||||
 | 
						if scope.Value == nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	call := func(value interface{}) {
 | 
						if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
 | 
				
			||||||
		if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
 | 
							for i := 0; i < indirectScopeValue.Len(); i++ {
 | 
				
			||||||
			switch f := fm.Interface().(type) {
 | 
								scope.callMethod(methodName, indirectScopeValue.Index(i))
 | 
				
			||||||
			case func():
 | 
					 | 
				
			||||||
				f()
 | 
					 | 
				
			||||||
			case func(s *Scope):
 | 
					 | 
				
			||||||
				f(scope)
 | 
					 | 
				
			||||||
			case func(s *DB):
 | 
					 | 
				
			||||||
				newDB := scope.NewDB()
 | 
					 | 
				
			||||||
				f(newDB)
 | 
					 | 
				
			||||||
				scope.Err(newDB.Error)
 | 
					 | 
				
			||||||
			case func() error:
 | 
					 | 
				
			||||||
				scope.Err(f())
 | 
					 | 
				
			||||||
			case func(s *Scope) error:
 | 
					 | 
				
			||||||
				scope.Err(f(scope))
 | 
					 | 
				
			||||||
			case func(s *DB) error:
 | 
					 | 
				
			||||||
				newDB := scope.NewDB()
 | 
					 | 
				
			||||||
				scope.Err(f(newDB))
 | 
					 | 
				
			||||||
				scope.Err(newDB.Error)
 | 
					 | 
				
			||||||
			default:
 | 
					 | 
				
			||||||
				scope.Err(fmt.Errorf("unsupported function %v", name))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
 | 
					 | 
				
			||||||
		for i := 0; i < values.Len(); i++ {
 | 
					 | 
				
			||||||
			value := values.Index(i).Addr().Interface()
 | 
					 | 
				
			||||||
			if values.Index(i).Kind() == reflect.Ptr {
 | 
					 | 
				
			||||||
				value = values.Index(i).Interface()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			call(value)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		if scope.IndirectValue().CanAddr() {
 | 
							scope.callMethod(methodName, indirectScopeValue)
 | 
				
			||||||
			call(scope.IndirectValue().Addr().Interface())
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			call(scope.IndirectValue().Interface())
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) CallMethodWithErrorCheck(name string) {
 | 
					// AddToVars add value as sql's vars, used to prevent SQL injection
 | 
				
			||||||
	scope.CallMethod(name, true)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// AddToVars add value as sql's vars, gorm will escape them
 | 
					 | 
				
			||||||
func (scope *Scope) AddToVars(value interface{}) string {
 | 
					func (scope *Scope) AddToVars(value interface{}) string {
 | 
				
			||||||
	if expr, ok := value.(*expr); ok {
 | 
						if expr, ok := value.(*expr); ok {
 | 
				
			||||||
		exp := expr.expr
 | 
							exp := expr.expr
 | 
				
			||||||
@ -251,10 +244,10 @@ func (scope *Scope) AddToVars(value interface{}) string {
 | 
				
			|||||||
			exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
 | 
								exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return exp
 | 
							return exp
 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		scope.SqlVars = append(scope.SqlVars, value)
 | 
					 | 
				
			||||||
		return scope.Dialect().BinVar(len(scope.SqlVars))
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						scope.SQLVars = append(scope.SQLVars, value)
 | 
				
			||||||
 | 
						return scope.Dialect().BindVar(len(scope.SQLVars))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type tabler interface {
 | 
					type tabler interface {
 | 
				
			||||||
@ -265,7 +258,7 @@ type dbTabler interface {
 | 
				
			|||||||
	TableName(*DB) string
 | 
						TableName(*DB) string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TableName get table name
 | 
					// TableName return table name
 | 
				
			||||||
func (scope *Scope) TableName() string {
 | 
					func (scope *Scope) TableName() string {
 | 
				
			||||||
	if scope.Search != nil && len(scope.Search.tableName) > 0 {
 | 
						if scope.Search != nil && len(scope.Search.tableName) > 0 {
 | 
				
			||||||
		return scope.Search.tableName
 | 
							return scope.Search.tableName
 | 
				
			||||||
@ -282,44 +275,54 @@ func (scope *Scope) TableName() string {
 | 
				
			|||||||
	return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
 | 
						return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// QuotedTableName return quoted table name
 | 
				
			||||||
func (scope *Scope) QuotedTableName() (name string) {
 | 
					func (scope *Scope) QuotedTableName() (name string) {
 | 
				
			||||||
	if scope.Search != nil && len(scope.Search.tableName) > 0 {
 | 
						if scope.Search != nil && len(scope.Search.tableName) > 0 {
 | 
				
			||||||
		if strings.Index(scope.Search.tableName, " ") != -1 {
 | 
							if strings.Index(scope.Search.tableName, " ") != -1 {
 | 
				
			||||||
			return scope.Search.tableName
 | 
								return scope.Search.tableName
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return scope.Quote(scope.Search.tableName)
 | 
							return scope.Quote(scope.Search.tableName)
 | 
				
			||||||
	} else {
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return scope.Quote(scope.TableName())
 | 
						return scope.Quote(scope.TableName())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CombinedConditionSql get combined condition sql
 | 
					// CombinedConditionSql return combined condition sql
 | 
				
			||||||
func (scope *Scope) CombinedConditionSql() string {
 | 
					func (scope *Scope) CombinedConditionSql() string {
 | 
				
			||||||
	return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
 | 
						return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() +
 | 
				
			||||||
		scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
 | 
							scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// FieldByName find `gorm.Field` with field name or db name
 | 
				
			||||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
 | 
					func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							dbName           = ToDBName(name)
 | 
				
			||||||
 | 
							mostMatchedField *Field
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, field := range scope.Fields() {
 | 
						for _, field := range scope.Fields() {
 | 
				
			||||||
		if field.Name == name || field.DBName == name {
 | 
							if field.Name == name || field.DBName == name {
 | 
				
			||||||
			return field, true
 | 
								return field, true
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							if field.DBName == dbName {
 | 
				
			||||||
 | 
								mostMatchedField = field
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	return nil, false
 | 
						}
 | 
				
			||||||
 | 
						return mostMatchedField, mostMatchedField != nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Raw set sql
 | 
					// Raw set raw sql
 | 
				
			||||||
func (scope *Scope) Raw(sql string) *Scope {
 | 
					func (scope *Scope) Raw(sql string) *Scope {
 | 
				
			||||||
	scope.Sql = strings.Replace(sql, "$$", "?", -1)
 | 
						scope.SQL = strings.Replace(sql, "$$", "?", -1)
 | 
				
			||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Exec invoke sql
 | 
					// Exec perform generated SQL
 | 
				
			||||||
func (scope *Scope) Exec() *Scope {
 | 
					func (scope *Scope) Exec() *Scope {
 | 
				
			||||||
	defer scope.Trace(NowFunc())
 | 
						defer scope.trace(NowFunc())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
 | 
							if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
				
			||||||
			if count, err := result.RowsAffected(); scope.Err(err) == nil {
 | 
								if count, err := result.RowsAffected(); scope.Err(err) == nil {
 | 
				
			||||||
				scope.db.RowsAffected = count
 | 
									scope.db.RowsAffected = count
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -334,37 +337,32 @@ func (scope *Scope) Set(name string, value interface{}) *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get get value by name
 | 
					// Get get setting by name
 | 
				
			||||||
func (scope *Scope) Get(name string) (interface{}, bool) {
 | 
					func (scope *Scope) Get(name string) (interface{}, bool) {
 | 
				
			||||||
	return scope.db.Get(name)
 | 
						return scope.db.Get(name)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// InstanceId get InstanceId for scope
 | 
					// InstanceID get InstanceID for scope
 | 
				
			||||||
func (scope *Scope) InstanceId() string {
 | 
					func (scope *Scope) InstanceID() string {
 | 
				
			||||||
	if scope.instanceId == "" {
 | 
						if scope.instanceID == "" {
 | 
				
			||||||
		scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db)
 | 
							scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return scope.instanceId
 | 
						return scope.instanceID
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback
 | 
				
			||||||
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
 | 
					func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
 | 
				
			||||||
	return scope.Set(name+scope.InstanceId(), value)
 | 
						return scope.Set(name+scope.InstanceID(), value)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// InstanceGet get instance setting from current operation
 | 
				
			||||||
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
 | 
					func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
 | 
				
			||||||
	return scope.Get(name + scope.InstanceId())
 | 
						return scope.Get(name + scope.InstanceID())
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Trace print sql log
 | 
					 | 
				
			||||||
func (scope *Scope) Trace(t time.Time) {
 | 
					 | 
				
			||||||
	if len(scope.Sql) > 0 {
 | 
					 | 
				
			||||||
		scope.db.slog(scope.Sql, t, scope.SqlVars...)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Begin start a transaction
 | 
					// Begin start a transaction
 | 
				
			||||||
func (scope *Scope) Begin() *Scope {
 | 
					func (scope *Scope) Begin() *Scope {
 | 
				
			||||||
	if db, ok := scope.SqlDB().(sqlDb); ok {
 | 
						if db, ok := scope.SQLDB().(sqlDb); ok {
 | 
				
			||||||
		if tx, err := db.Begin(); err == nil {
 | 
							if tx, err := db.Begin(); err == nil {
 | 
				
			||||||
			scope.db.db = interface{}(tx).(sqlCommon)
 | 
								scope.db.db = interface{}(tx).(sqlCommon)
 | 
				
			||||||
			scope.InstanceSet("gorm:started_transaction", true)
 | 
								scope.InstanceSet("gorm:started_transaction", true)
 | 
				
			||||||
@ -373,7 +371,7 @@ func (scope *Scope) Begin() *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
 | 
					// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
 | 
				
			||||||
func (scope *Scope) CommitOrRollback() *Scope {
 | 
					func (scope *Scope) CommitOrRollback() *Scope {
 | 
				
			||||||
	if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
 | 
						if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
 | 
				
			||||||
		if db, ok := scope.db.db.(sqlTx); ok {
 | 
							if db, ok := scope.db.db.(sqlTx); ok {
 | 
				
			||||||
@ -388,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SelectAttrs return selected attributes
 | 
				
			||||||
func (scope *Scope) SelectAttrs() []string {
 | 
					func (scope *Scope) SelectAttrs() []string {
 | 
				
			||||||
	if scope.selectAttrs == nil {
 | 
						if scope.selectAttrs == nil {
 | 
				
			||||||
		attrs := []string{}
 | 
							attrs := []string{}
 | 
				
			||||||
@ -407,57 +406,38 @@ func (scope *Scope) SelectAttrs() []string {
 | 
				
			|||||||
	return *scope.selectAttrs
 | 
						return *scope.selectAttrs
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// OmitAttrs return omited attributes
 | 
				
			||||||
func (scope *Scope) OmitAttrs() []string {
 | 
					func (scope *Scope) OmitAttrs() []string {
 | 
				
			||||||
	return scope.Search.omits
 | 
						return scope.Search.omits
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) changeableDBColumn(column string) bool {
 | 
					func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
 | 
				
			||||||
	selectAttrs := scope.SelectAttrs()
 | 
						var values = make([]interface{}, len(columns))
 | 
				
			||||||
	omitAttrs := scope.OmitAttrs()
 | 
						var ignored interface{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(selectAttrs) > 0 {
 | 
						for index, column := range columns {
 | 
				
			||||||
		for _, attr := range selectAttrs {
 | 
							if field, ok := fieldsMap[column]; ok {
 | 
				
			||||||
			if column == ToDBName(attr) {
 | 
								if field.Field.Kind() == reflect.Ptr {
 | 
				
			||||||
				return true
 | 
									values[index] = field.Field.Addr().Interface()
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
 | 
				
			||||||
 | 
									reflectValue.Elem().Set(field.Field.Addr())
 | 
				
			||||||
 | 
									values[index] = reflectValue.Interface()
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							} else {
 | 
				
			||||||
		return false
 | 
								values[index] = &ignored
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, attr := range omitAttrs {
 | 
					 | 
				
			||||||
		if column == ToDBName(attr) {
 | 
					 | 
				
			||||||
			return false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) changeableField(field *Field) bool {
 | 
					 | 
				
			||||||
	selectAttrs := scope.SelectAttrs()
 | 
					 | 
				
			||||||
	omitAttrs := scope.OmitAttrs()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(selectAttrs) > 0 {
 | 
					 | 
				
			||||||
		for _, attr := range selectAttrs {
 | 
					 | 
				
			||||||
			if field.Name == attr || field.DBName == attr {
 | 
					 | 
				
			||||||
				return true
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, attr := range omitAttrs {
 | 
					 | 
				
			||||||
		if field.Name == attr || field.DBName == attr {
 | 
					 | 
				
			||||||
			return false
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return !field.IsIgnored
 | 
						scope.Err(rows.Scan(values...))
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) shouldSaveAssociations() bool {
 | 
						for index, column := range columns {
 | 
				
			||||||
	saveAssociations, ok := scope.Get("gorm:save_associations")
 | 
							if field, ok := fieldsMap[column]; ok {
 | 
				
			||||||
	if ok && !saveAssociations.(bool) {
 | 
								if field.Field.Kind() != reflect.Ptr {
 | 
				
			||||||
		return false
 | 
									if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
 | 
				
			||||||
 | 
										field.Field.Set(v)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return true && !scope.HasError()
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										278
									
								
								scope_private.go
									
									
									
									
									
								
							
							
						
						
									
										278
									
								
								scope_private.go
									
									
									
									
									
								
							@ -8,6 +8,7 @@ import (
 | 
				
			|||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) primaryCondition(value interface{}) string {
 | 
					func (scope *Scope) primaryCondition(value interface{}) string {
 | 
				
			||||||
@ -75,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
 | 
					func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
 | 
				
			||||||
	var notEqualSql string
 | 
						var notEqualSQL string
 | 
				
			||||||
	var primaryKey = scope.PrimaryKey()
 | 
						var primaryKey = scope.PrimaryKey()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	switch value := clause["query"].(type) {
 | 
						switch value := clause["query"].(type) {
 | 
				
			||||||
@ -86,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
 | 
				
			|||||||
			return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
 | 
								return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
 | 
				
			||||||
		} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
 | 
							} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
 | 
				
			||||||
			str = fmt.Sprintf(" NOT (%v) ", value)
 | 
								str = fmt.Sprintf(" NOT (%v) ", value)
 | 
				
			||||||
			notEqualSql = fmt.Sprintf("NOT (%v)", value)
 | 
								notEqualSQL = fmt.Sprintf("NOT (%v)", value)
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
 | 
								str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
 | 
				
			||||||
			notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
 | 
								notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
 | 
						case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
 | 
				
			||||||
		return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
 | 
							return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
 | 
				
			||||||
@ -138,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
 | 
				
			|||||||
			if scanner, ok := interface{}(arg).(driver.Valuer); ok {
 | 
								if scanner, ok := interface{}(arg).(driver.Valuer); ok {
 | 
				
			||||||
				arg, _ = scanner.Value()
 | 
									arg, _ = scanner.Value()
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
 | 
								str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
@ -172,17 +173,20 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) whereSql() (sql string) {
 | 
					func (scope *Scope) whereSQL() (sql string) {
 | 
				
			||||||
	var primaryConditions, andConditions, orConditions []string
 | 
						var (
 | 
				
			||||||
 | 
							quotedTableName                                = scope.QuotedTableName()
 | 
				
			||||||
 | 
							primaryConditions, andConditions, orConditions []string
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
 | 
						if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
 | 
				
			||||||
		sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
 | 
							sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
 | 
				
			||||||
		primaryConditions = append(primaryConditions, sql)
 | 
							primaryConditions = append(primaryConditions, sql)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.PrimaryKeyZero() {
 | 
						if !scope.PrimaryKeyZero() {
 | 
				
			||||||
		for _, field := range scope.PrimaryFields() {
 | 
							for _, field := range scope.PrimaryFields() {
 | 
				
			||||||
			sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
 | 
								sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
 | 
				
			||||||
			primaryConditions = append(primaryConditions, sql)
 | 
								primaryConditions = append(primaryConditions, sql)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -205,30 +209,30 @@ func (scope *Scope) whereSql() (sql string) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	orSql := strings.Join(orConditions, " OR ")
 | 
						orSQL := strings.Join(orConditions, " OR ")
 | 
				
			||||||
	combinedSql := strings.Join(andConditions, " AND ")
 | 
						combinedSQL := strings.Join(andConditions, " AND ")
 | 
				
			||||||
	if len(combinedSql) > 0 {
 | 
						if len(combinedSQL) > 0 {
 | 
				
			||||||
		if len(orSql) > 0 {
 | 
							if len(orSQL) > 0 {
 | 
				
			||||||
			combinedSql = combinedSql + " OR " + orSql
 | 
								combinedSQL = combinedSQL + " OR " + orSQL
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		combinedSql = orSql
 | 
							combinedSQL = orSQL
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(primaryConditions) > 0 {
 | 
						if len(primaryConditions) > 0 {
 | 
				
			||||||
		sql = "WHERE " + strings.Join(primaryConditions, " AND ")
 | 
							sql = "WHERE " + strings.Join(primaryConditions, " AND ")
 | 
				
			||||||
		if len(combinedSql) > 0 {
 | 
							if len(combinedSQL) > 0 {
 | 
				
			||||||
			sql = sql + " AND (" + combinedSql + ")"
 | 
								sql = sql + " AND (" + combinedSQL + ")"
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else if len(combinedSql) > 0 {
 | 
						} else if len(combinedSQL) > 0 {
 | 
				
			||||||
		sql = "WHERE " + combinedSql
 | 
							sql = "WHERE " + combinedSQL
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) selectSql() string {
 | 
					func (scope *Scope) selectSQL() string {
 | 
				
			||||||
	if len(scope.Search.selects) == 0 {
 | 
						if len(scope.Search.selects) == 0 {
 | 
				
			||||||
		if scope.Search.joins != "" {
 | 
							if len(scope.Search.joinConditions) > 0 {
 | 
				
			||||||
			return fmt.Sprintf("%v.*", scope.QuotedTableName())
 | 
								return fmt.Sprintf("%v.*", scope.QuotedTableName())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return "*"
 | 
							return "*"
 | 
				
			||||||
@ -236,87 +240,60 @@ func (scope *Scope) selectSql() string {
 | 
				
			|||||||
	return scope.buildSelectQuery(scope.Search.selects)
 | 
						return scope.buildSelectQuery(scope.Search.selects)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) orderSql() string {
 | 
					func (scope *Scope) orderSQL() string {
 | 
				
			||||||
	if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
 | 
						if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
 | 
				
			||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return " ORDER BY " + strings.Join(scope.Search.orders, ",")
 | 
						return " ORDER BY " + strings.Join(scope.Search.orders, ",")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) limitSql() string {
 | 
					func (scope *Scope) limitAndOffsetSQL() string {
 | 
				
			||||||
	if !scope.Dialect().HasTop() {
 | 
						return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
 | 
				
			||||||
		if len(scope.Search.limit) == 0 {
 | 
					 | 
				
			||||||
			return ""
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return " LIMIT " + scope.Search.limit
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return ""
 | 
					func (scope *Scope) groupSQL() string {
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) topSql() string {
 | 
					 | 
				
			||||||
	if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
 | 
					 | 
				
			||||||
		if len(scope.Search.limit) == 0 {
 | 
					 | 
				
			||||||
			return ""
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return " TOP(" + scope.Search.limit + ")"
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) offsetSql() string {
 | 
					 | 
				
			||||||
	if len(scope.Search.offset) == 0 {
 | 
					 | 
				
			||||||
		return ""
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if scope.Dialect().HasTop() {
 | 
					 | 
				
			||||||
		sql := " OFFSET " + scope.Search.offset + " ROW "
 | 
					 | 
				
			||||||
		if len(scope.Search.limit) > 0 {
 | 
					 | 
				
			||||||
			sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return sql
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return " OFFSET " + scope.Search.offset
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) groupSql() string {
 | 
					 | 
				
			||||||
	if len(scope.Search.group) == 0 {
 | 
						if len(scope.Search.group) == 0 {
 | 
				
			||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return " GROUP BY " + scope.Search.group
 | 
						return " GROUP BY " + scope.Search.group
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) havingSql() string {
 | 
					func (scope *Scope) havingSQL() string {
 | 
				
			||||||
	if scope.Search.havingConditions == nil {
 | 
						if len(scope.Search.havingConditions) == 0 {
 | 
				
			||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var andConditions []string
 | 
						var andConditions []string
 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, clause := range scope.Search.havingConditions {
 | 
						for _, clause := range scope.Search.havingConditions {
 | 
				
			||||||
		if sql := scope.buildWhereCondition(clause); sql != "" {
 | 
							if sql := scope.buildWhereCondition(clause); sql != "" {
 | 
				
			||||||
			andConditions = append(andConditions, sql)
 | 
								andConditions = append(andConditions, sql)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	combinedSql := strings.Join(andConditions, " AND ")
 | 
						combinedSQL := strings.Join(andConditions, " AND ")
 | 
				
			||||||
	if len(combinedSql) == 0 {
 | 
						if len(combinedSQL) == 0 {
 | 
				
			||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return " HAVING " + combinedSql
 | 
						return " HAVING " + combinedSQL
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) joinsSql() string {
 | 
					func (scope *Scope) joinsSQL() string {
 | 
				
			||||||
	return scope.Search.joins + " "
 | 
						var joinConditions []string
 | 
				
			||||||
 | 
						for _, clause := range scope.Search.joinConditions {
 | 
				
			||||||
 | 
							if sql := scope.buildWhereCondition(clause); sql != "" {
 | 
				
			||||||
 | 
								joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) prepareQuerySql() {
 | 
						return strings.Join(joinConditions, " ") + " "
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) prepareQuerySQL() {
 | 
				
			||||||
	if scope.Search.raw {
 | 
						if scope.Search.raw {
 | 
				
			||||||
		scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
 | 
							scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
 | 
							scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -338,61 +315,53 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
 | 
					func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
 | 
				
			||||||
	if !scope.IndirectValue().CanAddr() {
 | 
						if scope.IndirectValue().Kind() != reflect.Struct {
 | 
				
			||||||
		return values, true
 | 
							return values, true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var hasExpr bool
 | 
						results = map[string]interface{}{}
 | 
				
			||||||
	for key, value := range values {
 | 
						for key, value := range values {
 | 
				
			||||||
		if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
 | 
							if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
 | 
				
			||||||
			if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
 | 
								if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
 | 
				
			||||||
				if _, ok := value.(*expr); ok {
 | 
									if _, ok := value.(*expr); ok {
 | 
				
			||||||
					hasExpr = true
 | 
					 | 
				
			||||||
				} else if !equalAsString(field.Field.Interface(), value) {
 | 
					 | 
				
			||||||
					hasUpdate = true
 | 
										hasUpdate = true
 | 
				
			||||||
 | 
										results[field.DBName] = value
 | 
				
			||||||
 | 
									} else if !equalAsString(field.Field.Interface(), value) {
 | 
				
			||||||
 | 
										field.Set(value)
 | 
				
			||||||
 | 
										if field.IsNormal {
 | 
				
			||||||
 | 
											hasUpdate = true
 | 
				
			||||||
 | 
											results[field.DBName] = field.Field.Interface()
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
				field.Set(value)
 | 
									field.Set(value)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if hasExpr {
 | 
					 | 
				
			||||||
		var updateMap = map[string]interface{}{}
 | 
					 | 
				
			||||||
		for key, field := range scope.Fields() {
 | 
					 | 
				
			||||||
			if field.IsNormal {
 | 
					 | 
				
			||||||
				if v, ok := values[key]; ok {
 | 
					 | 
				
			||||||
					updateMap[key] = v
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					updateMap[key] = field.Field.Interface()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return updateMap, true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) row() *sql.Row {
 | 
					func (scope *Scope) row() *sql.Row {
 | 
				
			||||||
	defer scope.Trace(NowFunc())
 | 
						defer scope.trace(NowFunc())
 | 
				
			||||||
	scope.callCallbacks(scope.db.parent.callback.rowQueries)
 | 
						scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
 | 
				
			||||||
	scope.prepareQuerySql()
 | 
						scope.prepareQuerySQL()
 | 
				
			||||||
	return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
 | 
						return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) rows() (*sql.Rows, error) {
 | 
					func (scope *Scope) rows() (*sql.Rows, error) {
 | 
				
			||||||
	defer scope.Trace(NowFunc())
 | 
						defer scope.trace(NowFunc())
 | 
				
			||||||
	scope.callCallbacks(scope.db.parent.callback.rowQueries)
 | 
						scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
 | 
				
			||||||
	scope.prepareQuerySql()
 | 
						scope.prepareQuerySQL()
 | 
				
			||||||
	return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
 | 
						return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) initialize() *Scope {
 | 
					func (scope *Scope) initialize() *Scope {
 | 
				
			||||||
	for _, clause := range scope.Search.whereConditions {
 | 
						for _, clause := range scope.Search.whereConditions {
 | 
				
			||||||
		scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
 | 
							scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
 | 
						scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
 | 
				
			||||||
	scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
 | 
						scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
 | 
				
			||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -433,23 +402,45 @@ func (scope *Scope) typeName() string {
 | 
				
			|||||||
	return typ.Name()
 | 
						return typ.Name()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// trace print sql log
 | 
				
			||||||
 | 
					func (scope *Scope) trace(t time.Time) {
 | 
				
			||||||
 | 
						if len(scope.SQL) > 0 {
 | 
				
			||||||
 | 
							scope.db.slog(scope.SQL, t, scope.SQLVars...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) changeableField(field *Field) bool {
 | 
				
			||||||
 | 
						if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
 | 
				
			||||||
 | 
							for _, attr := range selectAttrs {
 | 
				
			||||||
 | 
								if field.Name == attr || field.DBName == attr {
 | 
				
			||||||
 | 
									return true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, attr := range scope.OmitAttrs() {
 | 
				
			||||||
 | 
							if field.Name == attr || field.DBName == attr {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) shouldSaveAssociations() bool {
 | 
				
			||||||
 | 
						if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return true && !scope.HasError()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 | 
					func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 | 
				
			||||||
	toScope := scope.db.NewScope(value)
 | 
						toScope := scope.db.NewScope(value)
 | 
				
			||||||
	fromFields := scope.Fields()
 | 
					 | 
				
			||||||
	toFields := toScope.Fields()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
 | 
						for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
 | 
				
			||||||
		var fromField, toField *Field
 | 
							fromField, _ := scope.FieldByName(foreignKey)
 | 
				
			||||||
		if field, ok := scope.FieldByName(foreignKey); ok {
 | 
							toField, _ := toScope.FieldByName(foreignKey)
 | 
				
			||||||
			fromField = field
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			fromField = fromFields[ToDBName(foreignKey)]
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if field, ok := toScope.FieldByName(foreignKey); ok {
 | 
					 | 
				
			||||||
			toField = field
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			toField = toFields[ToDBName(foreignKey)]
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if fromField != nil {
 | 
							if fromField != nil {
 | 
				
			||||||
			if relationship := fromField.Relationship; relationship != nil {
 | 
								if relationship := fromField.Relationship; relationship != nil {
 | 
				
			||||||
@ -508,30 +499,26 @@ func (scope *Scope) createJoinTable(field *StructField) {
 | 
				
			|||||||
	if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
 | 
						if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
 | 
				
			||||||
		joinTableHandler := relationship.JoinTableHandler
 | 
							joinTableHandler := relationship.JoinTableHandler
 | 
				
			||||||
		joinTable := joinTableHandler.Table(scope.db)
 | 
							joinTable := joinTableHandler.Table(scope.db)
 | 
				
			||||||
		if !scope.Dialect().HasTable(scope, joinTable) {
 | 
							if !scope.Dialect().HasTable(joinTable) {
 | 
				
			||||||
			toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
 | 
								toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var sqlTypes, primaryKeys []string
 | 
								var sqlTypes, primaryKeys []string
 | 
				
			||||||
			for idx, fieldName := range relationship.ForeignFieldNames {
 | 
								for idx, fieldName := range relationship.ForeignFieldNames {
 | 
				
			||||||
				if field, ok := scope.Fields()[fieldName]; ok {
 | 
									if field, ok := scope.FieldByName(fieldName); ok {
 | 
				
			||||||
					value := reflect.Indirect(reflect.New(field.Struct.Type))
 | 
										foreignKeyStruct := field.clone()
 | 
				
			||||||
					primaryKeySqlType := field.TagSettings["TYPE"]
 | 
										foreignKeyStruct.IsPrimaryKey = false
 | 
				
			||||||
					if primaryKeySqlType == "" {
 | 
										foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
 | 
				
			||||||
						primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
 | 
										sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
 | 
					 | 
				
			||||||
					primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
 | 
										primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for idx, fieldName := range relationship.AssociationForeignFieldNames {
 | 
								for idx, fieldName := range relationship.AssociationForeignFieldNames {
 | 
				
			||||||
				if field, ok := toScope.Fields()[fieldName]; ok {
 | 
									if field, ok := toScope.FieldByName(fieldName); ok {
 | 
				
			||||||
					value := reflect.Indirect(reflect.New(field.Struct.Type))
 | 
										foreignKeyStruct := field.clone()
 | 
				
			||||||
					primaryKeySqlType := field.TagSettings["TYPE"]
 | 
										foreignKeyStruct.IsPrimaryKey = false
 | 
				
			||||||
					if primaryKeySqlType == "" {
 | 
										foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
 | 
				
			||||||
						primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
 | 
										sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
 | 
					 | 
				
			||||||
					primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
 | 
										primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -545,10 +532,10 @@ func (scope *Scope) createJoinTable(field *StructField) {
 | 
				
			|||||||
func (scope *Scope) createTable() *Scope {
 | 
					func (scope *Scope) createTable() *Scope {
 | 
				
			||||||
	var tags []string
 | 
						var tags []string
 | 
				
			||||||
	var primaryKeys []string
 | 
						var primaryKeys []string
 | 
				
			||||||
	var primaryKeyInColumnType bool = false
 | 
						var primaryKeyInColumnType = false
 | 
				
			||||||
	for _, field := range scope.GetStructFields() {
 | 
						for _, field := range scope.GetModelStruct().StructFields {
 | 
				
			||||||
		if field.IsNormal {
 | 
							if field.IsNormal {
 | 
				
			||||||
			sqlTag := scope.generateSqlTag(field)
 | 
								sqlTag := scope.Dialect().DataTypeOf(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Check if the primary key constraint was specified as
 | 
								// Check if the primary key constraint was specified as
 | 
				
			||||||
			// part of the column type. If so, we can only support
 | 
								// part of the column type. If so, we can only support
 | 
				
			||||||
@ -582,13 +569,6 @@ func (scope *Scope) dropTable() *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) dropTableIfExists() *Scope {
 | 
					 | 
				
			||||||
	if scope.Dialect().HasTable(scope, scope.TableName()) {
 | 
					 | 
				
			||||||
		scope.dropTable()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return scope
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (scope *Scope) modifyColumn(column string, typ string) {
 | 
					func (scope *Scope) modifyColumn(column string, typ string) {
 | 
				
			||||||
	scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
 | 
						scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -598,13 +578,13 @@ func (scope *Scope) dropColumn(column string) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
 | 
					func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
 | 
				
			||||||
	if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
 | 
						if scope.Dialect().HasIndex(scope.TableName(), indexName) {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var columns []string
 | 
						var columns []string
 | 
				
			||||||
	for _, name := range column {
 | 
						for _, name := range column {
 | 
				
			||||||
		columns = append(columns, scope.QuoteIfPossible(name))
 | 
							columns = append(columns, scope.quoteIfPossible(name))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sqlCreate := "CREATE INDEX"
 | 
						sqlCreate := "CREATE INDEX"
 | 
				
			||||||
@ -612,31 +592,35 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
 | 
				
			|||||||
		sqlCreate = "CREATE UNIQUE INDEX"
 | 
							sqlCreate = "CREATE UNIQUE INDEX"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
 | 
						scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
 | 
					func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
 | 
				
			||||||
	var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
 | 
						var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
 | 
				
			||||||
	keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
 | 
						keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
 | 
						var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
 | 
				
			||||||
	scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
 | 
						scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) removeIndex(indexName string) {
 | 
					func (scope *Scope) removeIndex(indexName string) {
 | 
				
			||||||
	scope.Dialect().RemoveIndex(scope, indexName)
 | 
						scope.Dialect().RemoveIndex(scope.TableName(), indexName)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) autoMigrate() *Scope {
 | 
					func (scope *Scope) autoMigrate() *Scope {
 | 
				
			||||||
	tableName := scope.TableName()
 | 
						tableName := scope.TableName()
 | 
				
			||||||
	quotedTableName := scope.QuotedTableName()
 | 
						quotedTableName := scope.QuotedTableName()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.Dialect().HasTable(scope, tableName) {
 | 
						if !scope.Dialect().HasTable(tableName) {
 | 
				
			||||||
		scope.createTable()
 | 
							scope.createTable()
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		for _, field := range scope.GetStructFields() {
 | 
							for _, field := range scope.GetModelStruct().StructFields {
 | 
				
			||||||
			if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
 | 
								if !scope.Dialect().HasColumn(tableName, field.DBName) {
 | 
				
			||||||
				if field.IsNormal {
 | 
									if field.IsNormal {
 | 
				
			||||||
					sqlTag := scope.generateSqlTag(field)
 | 
										sqlTag := scope.Dialect().DataTypeOf(field)
 | 
				
			||||||
					scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
 | 
										scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										67
									
								
								scope_utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								scope_utils.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,67 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "reflect"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
 | 
				
			||||||
 | 
						for _, value := range values {
 | 
				
			||||||
 | 
							indirectValue := reflect.ValueOf(value)
 | 
				
			||||||
 | 
							for indirectValue.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
								indirectValue = indirectValue.Elem()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							switch indirectValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Slice:
 | 
				
			||||||
 | 
								for i := 0; i < indirectValue.Len(); i++ {
 | 
				
			||||||
 | 
									var result []interface{}
 | 
				
			||||||
 | 
									var object = indirect(indirectValue.Index(i))
 | 
				
			||||||
 | 
									for _, column := range columns {
 | 
				
			||||||
 | 
										result = append(result, object.FieldByName(column).Interface())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									results = append(results, result)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Struct:
 | 
				
			||||||
 | 
								var result []interface{}
 | 
				
			||||||
 | 
								for _, column := range columns {
 | 
				
			||||||
 | 
									result = append(result, indirectValue.FieldByName(column).Interface())
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								results = append(results, result)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) getColumnAsScope(column string) *Scope {
 | 
				
			||||||
 | 
						indirectScopeValue := scope.IndirectValue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch indirectScopeValue.Kind() {
 | 
				
			||||||
 | 
						case reflect.Slice:
 | 
				
			||||||
 | 
							if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
 | 
				
			||||||
 | 
								fieldType := fieldStruct.Type
 | 
				
			||||||
 | 
								if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
									fieldType = fieldType.Elem()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for i := 0; i < indirectScopeValue.Len(); i++ {
 | 
				
			||||||
 | 
									result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if result.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
										for j := 0; j < result.Len(); j++ {
 | 
				
			||||||
 | 
											if elem := result.Index(j); elem.CanAddr() {
 | 
				
			||||||
 | 
												results = reflect.Append(results, elem.Addr())
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									} else if result.CanAddr() {
 | 
				
			||||||
 | 
										results = reflect.Append(results, result.Addr())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return scope.New(results.Interface())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case reflect.Struct:
 | 
				
			||||||
 | 
							if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
 | 
				
			||||||
 | 
								return scope.New(field.Addr().Interface())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										24
									
								
								search.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								search.go
									
									
									
									
									
								
							@ -8,15 +8,15 @@ type search struct {
 | 
				
			|||||||
	orConditions     []map[string]interface{}
 | 
						orConditions     []map[string]interface{}
 | 
				
			||||||
	notConditions    []map[string]interface{}
 | 
						notConditions    []map[string]interface{}
 | 
				
			||||||
	havingConditions []map[string]interface{}
 | 
						havingConditions []map[string]interface{}
 | 
				
			||||||
 | 
						joinConditions   []map[string]interface{}
 | 
				
			||||||
	initAttrs        []interface{}
 | 
						initAttrs        []interface{}
 | 
				
			||||||
	assignAttrs      []interface{}
 | 
						assignAttrs      []interface{}
 | 
				
			||||||
	selects          map[string]interface{}
 | 
						selects          map[string]interface{}
 | 
				
			||||||
	omits            []string
 | 
						omits            []string
 | 
				
			||||||
	orders           []string
 | 
						orders           []string
 | 
				
			||||||
	joins            string
 | 
					 | 
				
			||||||
	preload          []searchPreload
 | 
						preload          []searchPreload
 | 
				
			||||||
	offset           string
 | 
						offset           int
 | 
				
			||||||
	limit            string
 | 
						limit            int
 | 
				
			||||||
	group            string
 | 
						group            string
 | 
				
			||||||
	tableName        string
 | 
						tableName        string
 | 
				
			||||||
	raw              bool
 | 
						raw              bool
 | 
				
			||||||
@ -82,18 +82,18 @@ func (s *search) Omit(columns ...string) *search {
 | 
				
			|||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *search) Limit(value interface{}) *search {
 | 
					func (s *search) Limit(limit int) *search {
 | 
				
			||||||
	s.limit = s.getInterfaceAsSql(value)
 | 
						s.limit = limit
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *search) Offset(value interface{}) *search {
 | 
					func (s *search) Offset(offset int) *search {
 | 
				
			||||||
	s.offset = s.getInterfaceAsSql(value)
 | 
						s.offset = offset
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *search) Group(query string) *search {
 | 
					func (s *search) Group(query string) *search {
 | 
				
			||||||
	s.group = s.getInterfaceAsSql(query)
 | 
						s.group = s.getInterfaceAsSQL(query)
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -102,8 +102,8 @@ func (s *search) Having(query string, values ...interface{}) *search {
 | 
				
			|||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *search) Joins(query string) *search {
 | 
					func (s *search) Joins(query string, values ...interface{}) *search {
 | 
				
			||||||
	s.joins = query
 | 
						s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -134,12 +134,12 @@ func (s *search) Table(name string) *search {
 | 
				
			|||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
 | 
					func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
 | 
				
			||||||
	switch value.(type) {
 | 
						switch value.(type) {
 | 
				
			||||||
	case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
 | 
						case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
 | 
				
			||||||
		str = fmt.Sprintf("%v", value)
 | 
							str = fmt.Sprintf("%v", value)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		s.db.AddError(InvalidSql)
 | 
							s.db.AddError(ErrInvalidSQL)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if str == "-1" {
 | 
						if str == "-1" {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										84
									
								
								sqlite3.go
									
									
									
									
									
								
							
							
						
						
									
										84
									
								
								sqlite3.go
									
									
									
									
									
								
							@ -1,84 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type sqlite3 struct {
 | 
					 | 
				
			||||||
	commonDialect
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
 | 
					 | 
				
			||||||
	switch value.Kind() {
 | 
					 | 
				
			||||||
	case reflect.Bool:
 | 
					 | 
				
			||||||
		return "bool"
 | 
					 | 
				
			||||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "integer primary key autoincrement"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "integer"
 | 
					 | 
				
			||||||
	case reflect.Int64, reflect.Uint64:
 | 
					 | 
				
			||||||
		if autoIncrease {
 | 
					 | 
				
			||||||
			return "integer primary key autoincrement"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "bigint"
 | 
					 | 
				
			||||||
	case reflect.Float32, reflect.Float64:
 | 
					 | 
				
			||||||
		return "real"
 | 
					 | 
				
			||||||
	case reflect.String:
 | 
					 | 
				
			||||||
		if size > 0 && size < 65532 {
 | 
					 | 
				
			||||||
			return fmt.Sprintf("varchar(%d)", size)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return "text"
 | 
					 | 
				
			||||||
	case reflect.Struct:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().(time.Time); ok {
 | 
					 | 
				
			||||||
			return "datetime"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		if _, ok := value.Interface().([]byte); ok {
 | 
					 | 
				
			||||||
			return "blob"
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
 | 
					 | 
				
			||||||
	var count int
 | 
					 | 
				
			||||||
	s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
 | 
					 | 
				
			||||||
	return count > 0
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
 | 
					 | 
				
			||||||
	scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		ifaces   = make([]interface{}, 3)
 | 
					 | 
				
			||||||
		pointers = make([]*string, 3)
 | 
					 | 
				
			||||||
		i        int
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	for i = 0; i < 3; i++ {
 | 
					 | 
				
			||||||
		ifaces[i] = &pointers[i]
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if pointers[1] != nil {
 | 
					 | 
				
			||||||
		name = *pointers[1]
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -42,9 +42,9 @@ type CreditCard struct {
 | 
				
			|||||||
	ID        int8
 | 
						ID        int8
 | 
				
			||||||
	Number    string
 | 
						Number    string
 | 
				
			||||||
	UserId    sql.NullInt64
 | 
						UserId    sql.NullInt64
 | 
				
			||||||
	CreatedAt time.Time
 | 
						CreatedAt time.Time `sql:"not null"`
 | 
				
			||||||
	UpdatedAt time.Time
 | 
						UpdatedAt time.Time
 | 
				
			||||||
	DeletedAt time.Time
 | 
						DeletedAt *time.Time
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Email struct {
 | 
					type Email struct {
 | 
				
			||||||
@ -62,7 +62,7 @@ type Address struct {
 | 
				
			|||||||
	Post      string
 | 
						Post      string
 | 
				
			||||||
	CreatedAt time.Time
 | 
						CreatedAt time.Time
 | 
				
			||||||
	UpdatedAt time.Time
 | 
						UpdatedAt time.Time
 | 
				
			||||||
	DeletedAt time.Time
 | 
						DeletedAt *time.Time
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Language struct {
 | 
					type Language struct {
 | 
				
			||||||
 | 
				
			|||||||
@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DB.First(&product4, product4.Id)
 | 
						DB.First(&product4, product4.Id)
 | 
				
			||||||
 | 
						updatedAt4 := product4.UpdatedAt
 | 
				
			||||||
	DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
 | 
						DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
 | 
				
			||||||
	var product5 Product
 | 
						var product5 Product
 | 
				
			||||||
	DB.First(&product5, product4.Id)
 | 
						DB.First(&product5, product4.Id)
 | 
				
			||||||
	if product5.Price != product4.Price+100-50 {
 | 
						if product5.Price != product4.Price+100-50 {
 | 
				
			||||||
		t.Errorf("Update with expression")
 | 
							t.Errorf("Update with expression")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
 | 
						if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
 | 
				
			||||||
		t.Errorf("Update with expression should update UpdatedAt")
 | 
							t.Errorf("Update with expression should update UpdatedAt")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("product2's code should be updated")
 | 
							t.Errorf("product2's code should be updated")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						updatedAt4 := product4.UpdatedAt
 | 
				
			||||||
	DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
 | 
						DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
 | 
				
			||||||
	var product5 Product
 | 
						var product5 Product
 | 
				
			||||||
	DB.First(&product5, product4.Id)
 | 
						DB.First(&product5, product4.Id)
 | 
				
			||||||
	if product5.Price != product4.Price+100 {
 | 
						if product5.Price != product4.Price+100 {
 | 
				
			||||||
		t.Errorf("Updates with expression")
 | 
							t.Errorf("Updates with expression")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
 | 
						// product4's UpdatedAt will be reset when updating
 | 
				
			||||||
 | 
						if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
 | 
				
			||||||
		t.Errorf("Updates with expression should update UpdatedAt")
 | 
							t.Errorf("Updates with expression should update UpdatedAt")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -419,3 +422,32 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
 | 
							t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestUpdatesWithBlankValues(t *testing.T) {
 | 
				
			||||||
 | 
						product := Product{Code: "product1", Price: 10}
 | 
				
			||||||
 | 
						DB.Save(&product)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var product1 Product
 | 
				
			||||||
 | 
						DB.First(&product1, product.Id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if product1.Code != "product1" || product1.Price != 100 {
 | 
				
			||||||
 | 
							t.Errorf("product's code should not be updated")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestUpdateDecodeVirtualAttributes(t *testing.T) {
 | 
				
			||||||
 | 
						var user = User{
 | 
				
			||||||
 | 
							Name:     "jinzhu",
 | 
				
			||||||
 | 
							IgnoreMe: 88,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB.Save(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if user.IgnoreMe != 100 {
 | 
				
			||||||
 | 
							t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										233
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										233
									
								
								utils.go
									
									
									
									
									
								
							@ -2,10 +2,26 @@ package gorm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"database/sql/driver"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NowFunc returns current time, this function is exported in order to be able
 | 
				
			||||||
 | 
					// to give the flexibility to the developer to customize it according to their
 | 
				
			||||||
 | 
					// needs, e.g:
 | 
				
			||||||
 | 
					//    gorm.NowFunc = func() time.Time {
 | 
				
			||||||
 | 
					//      return time.Now().UTC()
 | 
				
			||||||
 | 
					//    }
 | 
				
			||||||
 | 
					var NowFunc = func() time.Time {
 | 
				
			||||||
 | 
						return time.Now()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Copied from golint
 | 
					// Copied from golint
 | 
				
			||||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
 | 
					var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
 | 
				
			||||||
var commonInitialismsReplacer *strings.Replacer
 | 
					var commonInitialismsReplacer *strings.Replacer
 | 
				
			||||||
@ -41,30 +57,239 @@ func newSafeMap() *safeMap {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var smap = newSafeMap()
 | 
					var smap = newSafeMap()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type strCase bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						lower strCase = false
 | 
				
			||||||
 | 
						upper strCase = true
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ToDBName convert string to db name
 | 
				
			||||||
func ToDBName(name string) string {
 | 
					func ToDBName(name string) string {
 | 
				
			||||||
	if v := smap.Get(name); v != "" {
 | 
						if v := smap.Get(name); v != "" {
 | 
				
			||||||
		return v
 | 
							return v
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	value := commonInitialismsReplacer.Replace(name)
 | 
						if name == "" {
 | 
				
			||||||
	buf := bytes.NewBufferString("")
 | 
							return ""
 | 
				
			||||||
	for i, v := range value {
 | 
						}
 | 
				
			||||||
		if i > 0 && v >= 'A' && v <= 'Z' {
 | 
					
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							value                        = commonInitialismsReplacer.Replace(name)
 | 
				
			||||||
 | 
							buf                          = bytes.NewBufferString("")
 | 
				
			||||||
 | 
							lastCase, currCase, nextCase strCase
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i, v := range value[:len(value)-1] {
 | 
				
			||||||
 | 
							nextCase = value[i+1] >= 'A' && value[i+1] <= 'Z'
 | 
				
			||||||
 | 
							if i > 0 {
 | 
				
			||||||
 | 
								if currCase == upper {
 | 
				
			||||||
 | 
									if lastCase == upper && nextCase == upper {
 | 
				
			||||||
 | 
										buf.WriteRune(v)
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										if value[i-1] != '_' && value[i+1] != '_' {
 | 
				
			||||||
						buf.WriteRune('_')
 | 
											buf.WriteRune('_')
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
					buf.WriteRune(v)
 | 
										buf.WriteRune(v)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									buf.WriteRune(v)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								currCase = upper
 | 
				
			||||||
 | 
								buf.WriteRune(v)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							lastCase = currCase
 | 
				
			||||||
 | 
							currCase = nextCase
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						buf.WriteByte(value[len(value)-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s := strings.ToLower(buf.String())
 | 
						s := strings.ToLower(buf.String())
 | 
				
			||||||
	smap.Set(name, s)
 | 
						smap.Set(name, s)
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SQL expression
 | 
				
			||||||
type expr struct {
 | 
					type expr struct {
 | 
				
			||||||
	expr string
 | 
						expr string
 | 
				
			||||||
	args []interface{}
 | 
						args []interface{}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Expr generate raw SQL expression, for example:
 | 
				
			||||||
 | 
					//     DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
 | 
				
			||||||
func Expr(expression string, args ...interface{}) *expr {
 | 
					func Expr(expression string, args ...interface{}) *expr {
 | 
				
			||||||
	return &expr{expr: expression, args: args}
 | 
						return &expr{expr: expression, args: args}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func indirect(reflectValue reflect.Value) reflect.Value {
 | 
				
			||||||
 | 
						for reflectValue.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
							reflectValue = reflectValue.Elem()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return reflectValue
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func toQueryMarks(primaryValues [][]interface{}) string {
 | 
				
			||||||
 | 
						var results []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, primaryValue := range primaryValues {
 | 
				
			||||||
 | 
							var marks []string
 | 
				
			||||||
 | 
							for _ = range primaryValue {
 | 
				
			||||||
 | 
								marks = append(marks, "?")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(marks) > 1 {
 | 
				
			||||||
 | 
								results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								results = append(results, strings.Join(marks, ""))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return strings.Join(results, ",")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func toQueryCondition(scope *Scope, columns []string) string {
 | 
				
			||||||
 | 
						var newColumns []string
 | 
				
			||||||
 | 
						for _, column := range columns {
 | 
				
			||||||
 | 
							newColumns = append(newColumns, scope.Quote(column))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(columns) > 1 {
 | 
				
			||||||
 | 
							return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return strings.Join(newColumns, ",")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func toQueryValues(values [][]interface{}) (results []interface{}) {
 | 
				
			||||||
 | 
						for _, value := range values {
 | 
				
			||||||
 | 
							for _, v := range value {
 | 
				
			||||||
 | 
								results = append(results, v)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func fileWithLineNum() string {
 | 
				
			||||||
 | 
						for i := 2; i < 15; i++ {
 | 
				
			||||||
 | 
							_, file, line, ok := runtime.Caller(i)
 | 
				
			||||||
 | 
							if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
 | 
				
			||||||
 | 
								return fmt.Sprintf("%v:%v", file, line)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func isBlank(value reflect.Value) bool {
 | 
				
			||||||
 | 
						return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func toSearchableMap(attrs ...interface{}) (result interface{}) {
 | 
				
			||||||
 | 
						if len(attrs) > 1 {
 | 
				
			||||||
 | 
							if str, ok := attrs[0].(string); ok {
 | 
				
			||||||
 | 
								result = map[string]interface{}{str: attrs[1]}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else if len(attrs) == 1 {
 | 
				
			||||||
 | 
							if attr, ok := attrs[0].(map[string]interface{}); ok {
 | 
				
			||||||
 | 
								result = attr
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if attr, ok := attrs[0].(interface{}); ok {
 | 
				
			||||||
 | 
								result = attr
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func convertInterfaceToMap(values interface{}) map[string]interface{} {
 | 
				
			||||||
 | 
						attrs := map[string]interface{}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch value := values.(type) {
 | 
				
			||||||
 | 
						case map[string]interface{}:
 | 
				
			||||||
 | 
							return value
 | 
				
			||||||
 | 
						case []interface{}:
 | 
				
			||||||
 | 
							for _, v := range value {
 | 
				
			||||||
 | 
								for key, value := range convertInterfaceToMap(v) {
 | 
				
			||||||
 | 
									attrs[key] = value
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case interface{}:
 | 
				
			||||||
 | 
							reflectValue := reflect.ValueOf(values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							switch reflectValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Map:
 | 
				
			||||||
 | 
								for _, key := range reflectValue.MapKeys() {
 | 
				
			||||||
 | 
									attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								for _, field := range (&Scope{Value: values}).Fields() {
 | 
				
			||||||
 | 
									if !field.IsBlank {
 | 
				
			||||||
 | 
										attrs[field.DBName] = field.Field.Interface()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return attrs
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func equalAsString(a interface{}, b interface{}) bool {
 | 
				
			||||||
 | 
						return toString(a) == toString(b)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func toString(str interface{}) string {
 | 
				
			||||||
 | 
						if values, ok := str.([]interface{}); ok {
 | 
				
			||||||
 | 
							var results []string
 | 
				
			||||||
 | 
							for _, value := range values {
 | 
				
			||||||
 | 
								results = append(results, toString(value))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return strings.Join(results, "_")
 | 
				
			||||||
 | 
						} else if bytes, ok := str.([]byte); ok {
 | 
				
			||||||
 | 
							return string(bytes)
 | 
				
			||||||
 | 
						} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
 | 
				
			||||||
 | 
							return fmt.Sprintf("%v", reflectValue.Interface())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func makeSlice(elemType reflect.Type) interface{} {
 | 
				
			||||||
 | 
						if elemType.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
							elemType = elemType.Elem()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						sliceType := reflect.SliceOf(elemType)
 | 
				
			||||||
 | 
						slice := reflect.New(sliceType)
 | 
				
			||||||
 | 
						slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
 | 
				
			||||||
 | 
						return slice.Interface()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func strInSlice(a string, list []string) bool {
 | 
				
			||||||
 | 
						for _, b := range list {
 | 
				
			||||||
 | 
							if b == a {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// getValueFromFields return given fields's value
 | 
				
			||||||
 | 
					func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
 | 
				
			||||||
 | 
						// If value is a nil pointer, Indirect returns a zero Value!
 | 
				
			||||||
 | 
						// Therefor we need to check for a zero value,
 | 
				
			||||||
 | 
						// as FieldByName could panic
 | 
				
			||||||
 | 
						if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
 | 
				
			||||||
 | 
							for _, fieldName := range fieldNames {
 | 
				
			||||||
 | 
								if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
 | 
				
			||||||
 | 
									result := fieldValue.Interface()
 | 
				
			||||||
 | 
									if r, ok := result.(driver.Valuer); ok {
 | 
				
			||||||
 | 
										result, _ = r.Value()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									results = append(results, result)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func addExtraSpaceIfExist(str string) string {
 | 
				
			||||||
 | 
						if str != "" {
 | 
				
			||||||
 | 
							return " " + str
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,98 +0,0 @@
 | 
				
			|||||||
package gorm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"regexp"
 | 
					 | 
				
			||||||
	"runtime"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func fileWithLineNum() string {
 | 
					 | 
				
			||||||
	for i := 2; i < 15; i++ {
 | 
					 | 
				
			||||||
		_, file, line, ok := runtime.Caller(i)
 | 
					 | 
				
			||||||
		if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
 | 
					 | 
				
			||||||
			return fmt.Sprintf("%v:%v", file, line)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func isBlank(value reflect.Value) bool {
 | 
					 | 
				
			||||||
	return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func toSearchableMap(attrs ...interface{}) (result interface{}) {
 | 
					 | 
				
			||||||
	if len(attrs) > 1 {
 | 
					 | 
				
			||||||
		if str, ok := attrs[0].(string); ok {
 | 
					 | 
				
			||||||
			result = map[string]interface{}{str: attrs[1]}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else if len(attrs) == 1 {
 | 
					 | 
				
			||||||
		if attr, ok := attrs[0].(map[string]interface{}); ok {
 | 
					 | 
				
			||||||
			result = attr
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if attr, ok := attrs[0].(interface{}); ok {
 | 
					 | 
				
			||||||
			result = attr
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
 | 
					 | 
				
			||||||
	attrs := map[string]interface{}{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	switch value := values.(type) {
 | 
					 | 
				
			||||||
	case map[string]interface{}:
 | 
					 | 
				
			||||||
		for k, v := range value {
 | 
					 | 
				
			||||||
			attrs[ToDBName(k)] = v
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case []interface{}:
 | 
					 | 
				
			||||||
		for _, v := range value {
 | 
					 | 
				
			||||||
			for key, value := range convertInterfaceToMap(v) {
 | 
					 | 
				
			||||||
				attrs[key] = value
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case interface{}:
 | 
					 | 
				
			||||||
		reflectValue := reflect.ValueOf(values)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		switch reflectValue.Kind() {
 | 
					 | 
				
			||||||
		case reflect.Map:
 | 
					 | 
				
			||||||
			for _, key := range reflectValue.MapKeys() {
 | 
					 | 
				
			||||||
				attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		default:
 | 
					 | 
				
			||||||
			scope := Scope{Value: values}
 | 
					 | 
				
			||||||
			for _, field := range scope.Fields() {
 | 
					 | 
				
			||||||
				if !field.IsBlank && !field.IsIgnored {
 | 
					 | 
				
			||||||
					attrs[field.DBName] = field.Field.Interface()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return attrs
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func toString(str interface{}) string {
 | 
					 | 
				
			||||||
	if values, ok := str.([]interface{}); ok {
 | 
					 | 
				
			||||||
		var results []string
 | 
					 | 
				
			||||||
		for _, value := range values {
 | 
					 | 
				
			||||||
			results = append(results, toString(value))
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return strings.Join(results, "_")
 | 
					 | 
				
			||||||
	} else if bytes, ok := str.([]byte); ok {
 | 
					 | 
				
			||||||
		return string(bytes)
 | 
					 | 
				
			||||||
	} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
 | 
					 | 
				
			||||||
		return fmt.Sprintf("%v", reflectValue.Interface())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func strInSlice(a string, list []string) bool {
 | 
					 | 
				
			||||||
	for _, b := range list {
 | 
					 | 
				
			||||||
		if b == a {
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										30
									
								
								utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								utils_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					package gorm_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/jinzhu/gorm"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestToDBNameGenerateFriendlyName(t *testing.T) {
 | 
				
			||||||
 | 
						var maps = map[string]string{
 | 
				
			||||||
 | 
							"":                          "",
 | 
				
			||||||
 | 
							"ThisIsATest":               "this_is_a_test",
 | 
				
			||||||
 | 
							"PFAndESI":                  "pf_and_esi",
 | 
				
			||||||
 | 
							"AbcAndJkl":                 "abc_and_jkl",
 | 
				
			||||||
 | 
							"EmployeeID":                "employee_id",
 | 
				
			||||||
 | 
							"SKU_ID":                    "sku_id",
 | 
				
			||||||
 | 
							"HTTPAndSMTP":               "http_and_smtp",
 | 
				
			||||||
 | 
							"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
 | 
				
			||||||
 | 
							"UUID":     "uuid",
 | 
				
			||||||
 | 
							"HTTPURL":  "http_url",
 | 
				
			||||||
 | 
							"HTTP_URL": "http_url",
 | 
				
			||||||
 | 
							"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for key, value := range maps {
 | 
				
			||||||
 | 
							if gorm.ToDBName(key) != value {
 | 
				
			||||||
 | 
								t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user