diff --git a/association.go b/association.go deleted file mode 100644 index a73344fe..00000000 --- a/association.go +++ /dev/null @@ -1,377 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Association Mode contains some helper methods to handle relationship things easily. -type Association struct { - Error error - scope *Scope - column string - field *Field -} - -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} - -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames, associationForeignDBNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for idx, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) - } - } - } else { - // If has one/many relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, field.DBName) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - return association -} - -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break - } - } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break - } - } - } - } - - return association -} - -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { - return association.Replace() -} - -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - switch relationship.Kind { - case "many_to_many": - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - case "has_many", "has_one": - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - case "belongs_to": - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - relationship.PolymorphicValue, - ) - } - - if err := query.Model(fieldValue).Count(&count).Error; err != nil { - association.Error = err - } - return count -} - -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { - var ( - scope = association.scope - field = association.field - relationship = field.Relationship - ) - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -// setErr set error when the error is not nil. And return Association. -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} diff --git a/association_test.go b/association_test.go deleted file mode 100644 index 60d0cf48..00000000 --- a/association_test.go +++ /dev/null @@ -1,1050 +0,0 @@ -package gorm_test - -import ( - "fmt" - "os" - "reflect" - "sort" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestBelongsTo(t *testing.T) { - post := Post{ - Title: "post belongs to", - Body: "body belongs to", - Category: Category{Name: "Category 1"}, - MainCategory: Category{Name: "Main Category 1"}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - if post.Category.ID == 0 || post.MainCategory.ID == 0 { - t.Errorf("Category's primary key should be updated") - } - - if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 { - t.Errorf("post's foreign key should be updated") - } - - // Query - var category1 Category - DB.Model(&post).Association("Category").Find(&category1) - if category1.Name != "Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var mainCategory1 Category - DB.Model(&post).Association("MainCategory").Find(&mainCategory1) - if mainCategory1.Name != "Main Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var category11 Category - DB.Model(&post).Related(&category11) - if category11.Name != "Category 1" { - t.Errorf("Query belongs to relations with Related") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - if DB.Model(&post).Association("MainCategory").Count() != 1 { - t.Errorf("Post's main category count should be 1") - } - - // Append - var category2 = Category{ - Name: "Category 2", - } - DB.Model(&post).Association("Category").Append(&category2) - - if category2.ID == 0 { - t.Errorf("Category should has ID when created with Append") - } - - var category21 Category - DB.Model(&post).Related(&category21) - - if category21.Name != "Category 2" { - t.Errorf("Category should be updated with Append") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Replace - var category3 = Category{ - Name: "Category 3", - } - DB.Model(&post).Association("Category").Replace(&category3) - - if category3.ID == 0 { - t.Errorf("Category should has ID when created with Replace") - } - - var category31 Category - DB.Model(&post).Related(&category31) - if category31.Name != "Category 3" { - t.Errorf("Category should be updated with Replace") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Delete - DB.Model(&post).Association("Category").Delete(&category2) - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not delete any category when Delete a unrelated Category") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should not be reseted when Delete a unrelated Category") - } - - DB.Model(&post).Association("Category").Delete(&category3) - - if post.Category.Name != "" { - t.Errorf("Post's category should be reseted after Delete") - } - - var category41 Category - DB.Model(&post).Related(&category41) - if category41.Name != "" { - t.Errorf("Category should be deleted with Delete") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Delete, but got %v", count) - } - - // Clear - DB.Model(&post).Association("Category").Append(&Category{ - Name: "Category 2", - }) - - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should find category after append") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should has value after Append") - } - - DB.Model(&post).Association("Category").Clear() - - if post.Category.Name != "" { - t.Errorf("Post's category should be cleared after Clear") - } - - if !DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not find any category after Clear") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Clear, but got %v", count) - } - - // Check Association mode with soft delete - category6 := Category{ - Name: "Category 6", - } - DB.Model(&post).Association("Category").Append(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 after Append, but got %v", count) - } - - DB.Delete(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) - } - - if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { - t.Errorf("Post's category is not findable after Delete") - } - - if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { - t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) - } -} - -func TestBelongsToOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileRefer"` - ProfileRefer int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestBelongsToOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Refer string - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` - ProfileID int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOne(t *testing.T) { - user := User{ - Name: "has one", - CreditCard: CreditCard{Number: "411111111111"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Error("Got errors when save user", err.Error()) - } - - if user.CreditCard.UserId.Int64 == 0 { - t.Errorf("CreditCard's foreign key should be updated") - } - - // Query - var creditCard1 CreditCard - DB.Model(&user).Related(&creditCard1) - - if creditCard1.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - var creditCard11 CreditCard - DB.Model(&user).Association("CreditCard").Find(&creditCard11) - - if creditCard11.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Append - var creditcard2 = CreditCard{ - Number: "411111111112", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard2) - - if creditcard2.ID == 0 { - t.Errorf("Creditcard should has ID when created with Append") - } - - var creditcard21 CreditCard - DB.Model(&user).Related(&creditcard21) - if creditcard21.Number != "411111111112" { - t.Errorf("CreditCard should be updated with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Replace - var creditcard3 = CreditCard{ - Number: "411111111113", - } - DB.Model(&user).Association("CreditCard").Replace(&creditcard3) - - if creditcard3.ID == 0 { - t.Errorf("Creditcard should has ID when created with Replace") - } - - var creditcard31 CreditCard - DB.Model(&user).Related(&creditcard31) - if creditcard31.Number != "411111111113" { - t.Errorf("CreditCard should be updated with Replace") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Delete - DB.Model(&user).Association("CreditCard").Delete(&creditcard2) - var creditcard4 CreditCard - DB.Model(&user).Related(&creditcard4) - if creditcard4.Number != "411111111113" { - t.Errorf("Should not delete credit card when Delete a unrelated CreditCard") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Delete(&creditcard3) - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should delete credit card with Delete") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Delete") - } - - // Clear - var creditcard5 = CreditCard{ - Number: "411111111115", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard5) - - if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should added credit card with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Clear() - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Credit card should be deleted with Clear") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Clear") - } - - // Check Association mode with soft delete - var creditcard6 = CreditCard{ - Number: "411111111116", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 after Append, but got %v", count) - } - - DB.Delete(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { - t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) - } - - if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { - t.Errorf("User's creditcard is not findable after Delete") - } - - if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { - t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) - } -} - -func TestHasOneOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOneOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasMany(t *testing.T) { - post := Post{ - Title: "post has many", - Body: "body has many", - Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - for _, comment := range post.Comments { - if comment.PostId == 0 { - t.Errorf("comment's PostID should be updated") - } - } - - var compareComments = func(comments []Comment, contents []string) bool { - var commentContents []string - for _, comment := range comments { - commentContents = append(commentContents, comment.Content) - } - sort.Strings(commentContents) - sort.Strings(contents) - return reflect.DeepEqual(commentContents, contents) - } - - // Query - if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { - t.Errorf("Comment 1 should be saved") - } - - var comments1 []Comment - DB.Model(&post).Association("Comments").Find(&comments1) - if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Association") - } - - var comments11 []Comment - DB.Model(&post).Related(&comments11) - if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Related") - } - - if DB.Model(&post).Association("Comments").Count() != 2 { - t.Errorf("Post's comments count should be 2") - } - - // Append - DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) - - var comments2 []Comment - DB.Model(&post).Related(&comments2) - if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { - t.Errorf("Append new record to has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 3 { - t.Errorf("Post's comments count should be 3 after Append") - } - - // Delete - DB.Model(&post).Association("Comments").Delete(comments11) - - var comments3 []Comment - DB.Model(&post).Related(&comments3) - if !compareComments(comments3, []string{"Comment 3"}) { - t.Errorf("Delete an existing resource for has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 1 { - t.Errorf("Post's comments count should be 1 after Delete 2") - } - - // Replace - DB.Model(&Post{Id: 999}).Association("Comments").Replace() - - var comments4 []Comment - DB.Model(&post).Related(&comments4) - if len(comments4) == 0 { - t.Errorf("Replace for other resource should not clear all comments") - } - - DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"}) - - var comments41 []Comment - DB.Model(&post).Related(&comments41) - if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) { - t.Errorf("Replace has many relations") - } - - // Clear - DB.Model(&Post{Id: 999}).Association("Comments").Clear() - - var comments5 []Comment - DB.Model(&post).Related(&comments5) - if len(comments5) == 0 { - t.Errorf("Clear should not clear all comments") - } - - DB.Model(&post).Association("Comments").Clear() - - var comments51 []Comment - DB.Model(&post).Related(&comments51) - if len(comments51) != 0 { - t.Errorf("Clear has many relations") - } - - // Check Association mode with soft delete - var comment6 = Comment{ - Content: "comment 6", - } - DB.Model(&post).Association("Comments").Append(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 after Append, but got %v", count) - } - - DB.Delete(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 0 { - t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) - } - - var comments6 []Comment - if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { - t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) - } - - if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) - } - - var comments61 []Comment - if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) - } -} - -func TestHasManyOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile []Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasManyOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestManyToMany(t *testing.T) { - DB.Raw("delete from languages") - var languages = []Language{{Name: "ZH"}, {Name: "EN"}} - user := User{Name: "Many2Many", Languages: languages} - DB.Save(&user) - - // Query - var newLanguages []Language - DB.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Query many to many relations") - } - - DB.Model(&user).Association("Languages").Find(&newLanguages) - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Should be able to find many to many relations") - } - - if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { - t.Errorf("Count should return correct result") - } - - // Append - DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) - if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { - t.Errorf("New record should be saved when append") - } - - languageA := Language{Name: "AA"} - DB.Save(&languageA) - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) - - languageC := Language{Name: "CC"} - DB.Save(&languageC) - DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) - - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) - - totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { - t.Errorf("All appended languages should be saved") - } - - // Delete - user.Languages = []Language{} - DB.Model(&user).Association("Languages").Find(&user.Languages) - - var language Language - DB.Where("name = ?", "EE").First(&language) - DB.Model(&user).Association("Languages").Delete(language, &language) - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { - t.Errorf("Relations should be deleted with Delete") - } - if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { - t.Errorf("Language EE should not be deleted") - } - - DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) - - user2 := User{Name: "Many2Many_User2", Languages: languages} - DB.Save(&user2) - - DB.Model(&user).Association("Languages").Delete(languages, &languages) - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { - t.Errorf("Relations should be deleted with Delete") - } - - if DB.Model(&user2).Association("Languages").Count() == 0 { - t.Errorf("Other user's relations should not be deleted") - } - - // Replace - var languageB Language - DB.Where("name = ?", "BB").First(&languageB) - DB.Model(&user).Association("Languages").Replace(languageB) - if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { - t.Errorf("Relations should be replaced") - } - - DB.Model(&user).Association("Languages").Replace() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be replaced with empty") - } - - DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) - if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { - t.Errorf("Relations should be replaced") - } - - // Clear - DB.Model(&user).Association("Languages").Clear() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be cleared") - } - - // Check Association mode with soft delete - var language6 = Language{ - Name: "language 6", - } - DB.Model(&user).Association("Languages").Append(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 after Append, but got %v", count) - } - - DB.Delete(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 0 { - t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) - } - - var languages6 []Language - if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { - t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) - } - - if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) - } - - var languages61 []Language - if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) - } -} - -func TestRelated(t *testing.T) { - user := User{ - Name: "jinzhu", - BillingAddress: Address{Address1: "Billing Address - Address 1"}, - ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, - CreditCard: CreditCard{Number: "1234567890"}, - Company: Company{Name: "company1"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Errorf("No error should happen when saving user") - } - - if user.CreditCard.ID == 0 { - t.Errorf("After user save, credit card should have id") - } - - if user.BillingAddress.ID == 0 { - t.Errorf("After user save, billing address should have id") - } - - if user.Emails[0].Id == 0 { - t.Errorf("After user save, billing address should have id") - } - - var emails []Email - DB.Model(&user).Related(&emails) - if len(emails) != 2 { - t.Errorf("Should have two emails") - } - - var emails2 []Email - DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) - if len(emails2) != 1 { - t.Errorf("Should have two emails") - } - - var emails3 []*Email - DB.Model(&user).Related(&emails3) - if len(emails3) != 2 { - t.Errorf("Should have two emails") - } - - var user1 User - DB.Model(&user).Related(&user1.Emails) - if len(user1.Emails) != 2 { - t.Errorf("Should have only one email match related condition") - } - - var address1 Address - DB.Model(&user).Related(&address1, "BillingAddressId") - if address1.Address1 != "Billing Address - Address 1" { - t.Errorf("Should get billing address from user correctly") - } - - user1 = User{} - DB.Model(&address1).Related(&user1, "BillingAddressId") - if DB.NewRecord(user1) { - t.Errorf("Should get user from address correctly") - } - - var user2 User - DB.Model(&emails[0]).Related(&user2) - if user2.Id != user.Id || user2.Name != user.Name { - t.Errorf("Should get user from email correctly") - } - - var creditcard CreditCard - var user3 User - DB.First(&creditcard, "number = ?", "1234567890") - DB.Model(&creditcard).Related(&user3) - if user3.Id != user.Id || user3.Name != user.Name { - t.Errorf("Should get user from credit card correctly") - } - - if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() { - t.Errorf("RecordNotFound for Related") - } - - var company Company - if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" { - t.Errorf("RecordNotFound for Related") - } -} - -func TestForeignKey(t *testing.T) { - for _, structField := range DB.NewScope(&User{}).GetStructFields() { - for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Email{}).GetStructFields() { - for _, foreignKey := range []string{"UserId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Post{}).GetStructFields() { - for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { - for _, foreignKey := range []string{"PostId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } -} - -func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { - // sqlite does not support ADD CONSTRAINT in ALTER TABLE - return - } - targetScope := DB.NewScope(target) - targetTableName := targetScope.TableName() - modelScope := DB.NewScope(source) - modelField, ok := modelScope.FieldByName(sourceFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName)) - } - targetField, ok := targetScope.FieldByName(targetFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName)) - } - dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) - err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error - if err != nil { - t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) - } -} - -func TestLongForeignKey(t *testing.T) { - testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID") -} - -func TestLongForeignKeyWithShortDest(t *testing.T) { - testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID") -} - -func TestHasManyChildrenWithOneStruct(t *testing.T) { - category := Category{ - Name: "main", - Categories: []Category{ - {Name: "sub1"}, - {Name: "sub2"}, - }, - } - - DB.Save(&category) -} - -func TestAutoSaveBelongsToAssociation(t *testing.T) { - type Company struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Name string - CompanyID uint - Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` - } - - DB.Where("name = ?", "auto_save_association").Delete(&Company{}) - DB.AutoMigrate(&Company{}, &User{}) - - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) - - if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_association should not have been saved when autosave is false") - } - - // if foreign key is set, this should be saved even if association isn't - company := Company{Name: "auto_save_association"} - DB.Save(&company) - - company.Name = "auto_save_association_new_name" - user := User{Name: "jinzhu", Company: company} - - DB.Save(&user) - - if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { - t.Errorf("User's foreign key should have been saved") - } - - user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} - DB.Set("gorm:association_autocreate", true).Save(&user2) - if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_association_2 should been created when autocreate is true") - } - - user2.Company.Name = "auto_save_association_2_newname" - DB.Set("gorm:association_autoupdate", true).Save(&user2) - - if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { - t.Errorf("Company should been updated") - } -} - -func TestAutoSaveHasOneAssociation(t *testing.T) { - type Company struct { - gorm.Model - UserID uint - Name string - } - - type User struct { - gorm.Model - Name string - Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` - } - - DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) - DB.AutoMigrate(&Company{}, &User{}) - - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) - - if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") - } - - company := Company{Name: "auto_save_has_one_association"} - DB.Save(&company) - - company.Name = "auto_save_has_one_association_new_name" - user := User{Name: "jinzhu", Company: company} - - DB.Save(&user) - - if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if user.Company.UserID == 0 { - t.Errorf("UserID should be assigned") - } - - company.Name = "auto_save_has_one_association_2_new_name" - DB.Set("gorm:association_autoupdate", true).Save(&user) - - if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { - t.Errorf("Company should been updated") - } - - user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} - DB.Set("gorm:association_autocreate", true).Save(&user2) - if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") - } -} - -func TestAutoSaveMany2ManyAssociation(t *testing.T) { - type Company struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Name string - Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` - } - - DB.AutoMigrate(&Company{}, &User{}) - - DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) - - if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") - } - - company := Company{Name: "auto_save_m2m_association"} - DB.Save(&company) - - company.Name = "auto_save_m2m_association_new_name" - user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} - - DB.Save(&user) - - if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not been created") - } - - if DB.Model(&user).Association("Companies").Count() != 1 { - t.Errorf("Relationship should been saved") - } - - DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) - - if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should been updated") - } - - if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company should been created") - } - - if DB.Model(&user).Association("Companies").Count() != 2 { - t.Errorf("Relationship should been updated") - } -} diff --git a/callback.go b/callback.go deleted file mode 100644 index 1f0e3c79..00000000 --- a/callback.go +++ /dev/null @@ -1,250 +0,0 @@ -package gorm - -import "fmt" - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{logger: nopLogger{}} - -// Callback is a struct that contains all CRUD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - logger logger - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - logger logger - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone(logger logger) *Callback { - return &Callback{ - logger: logger, - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... -// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - if cp.kind == "row_query" { - if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) - cp.before = "gorm:row_query" - } - } - - cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("CreatedAt", now) -// scope.SetColumn("UpdatedAt", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind { - if p.remove { - callback = nil - } else { - callback = *p.processor - } - } - } - return -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CRUD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - c.rowQueries = sortProcessors(rowQueries) -} diff --git a/callback_create.go b/callback_create.go deleted file mode 100644 index c4d25f37..00000000 --- a/callback_create.go +++ /dev/null @@ -1,197 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := scope.db.nowFunc() - - if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { - if createdAtField.IsBlank { - createdAtField.Set(now) - } - } - - if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { - if updatedAtField.IsBlank { - updatedAtField.Set(now) - } - } - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal && !field.IsIgnored { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - insertModifier string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - if str, ok := scope.Get("gorm:insert_modifier"); ok { - insertModifier = strings.ToUpper(fmt.Sprint(str)) - if insertModifier == "INTO" { - insertModifier = "" - } - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) - var lastInsertIDReturningSuffix string - if lastInsertIDOutputInterstitial == "" { - lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - } - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v %v%v%v", - addExtraSpaceIfExist(insertModifier), - quotedTableName, - scope.Dialect().DefaultValueStr(), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", - addExtraSpaceIfExist(insertModifier), - scope.QuotedTableName(), - strings.Join(columns, ","), - addExtraSpaceIfExist(lastInsertIDOutputInterstitial), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql: no primaryField - if primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: lastInsertID implemention for majority of dialects - if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - return - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/callback_delete.go b/callback_delete.go deleted file mode 100644 index 48b97acb..00000000 --- a/callback_delete.go +++ /dev/null @@ -1,63 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" -) - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause while deleting")) - return - } - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") - - if !scope.Search.Unscoped && hasDeletedAtField { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v=%v%v%v", - scope.QuotedTableName(), - scope.Quote(deletedAtField.DBName), - scope.AddToVars(scope.db.nowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/callback_query.go b/callback_query.go deleted file mode 100644 index 544afd63..00000000 --- a/callback_query.go +++ /dev/null @@ -1,109 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - //we are only preloading relations, dont touch base model - if _, skip := scope.InstanceGet("gorm:only_preload"); skip { - return - } - - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - - if str, ok := scope.Get("gorm:query_hint"); ok { - scope.SQL = fmt.Sprint(str) + scope.SQL - } - - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } else if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/callback_query_preload.go b/callback_query_preload.go deleted file mode 100644 index a936180a..00000000 --- a/callback_query_preload.go +++ /dev/null @@ -1,410 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - if ap, ok := scope.Get("gorm:auto_preload"); ok { - // If gorm:auto_preload IS NOT a bool then auto preload. - // Else if it IS a bool, use the value - if apb, ok := ap.(bool); !ok { - autoPreload(scope) - } else if apb { - autoPreload(scope) - } - } - - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func autoPreload(scope *Scope) { - for _, field := range scope.Fields() { - if field.Relationship == nil { - continue - } - - if val, ok := field.TagSettingsGet("PRELOAD"); ok { - if preload, err := strconv.ParseBool(val); err != nil { - scope.Err(errors.New("invalid preload option")) - return - } else if !preload { - continue - } - } - - scope.Search.Preload(field.Name) - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - foreignValuesToResults := make(map[string]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) - foreignValuesToResults[foreignValues] = result - } - for j := 0; j < indirectScopeValue.Len(); j++ { - indirectValue := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) - if result, found := foreignValuesToResults[valueString]; found { - indirectValue.FieldByName(field.Name).Set(result) - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - f := object.FieldByName(field.Name) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f.Set(reflect.Append(f, results...)) - } else { - f.Set(reflect.MakeSlice(f.Type(), 0, 0)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - foreignFieldToObjects := make(map[string][]*reflect.Value) - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) - foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) - } - } - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) - if objects, found := foreignFieldToObjects[valueString]; found { - for _, object := range objects { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) - - if len(preloadDB.search.selects) == 0 { - preloadDB = preloadDB.Select("*") - } - - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - scope.New(elem.Addr().Interface()). - InstanceSet("gorm:skip_query_callback", true). - callCallbacks(scope.db.parent.callbacks.queries) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string][]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - key := toString(getValueFromFields(object, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) - } - } else if indirectScopeValue.IsValid() { - key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) - } - - for source, fields := range fieldsSourceMap { - for _, f := range fields { - //If not 0 this means Value is a pointer and we already added preloaded models to it - if f.Len() != 0 { - continue - } - - v := reflect.MakeSlice(f.Type(), 0, 0) - if len(linkHash[source]) > 0 { - v = reflect.Append(f, linkHash[source]...) - } - - f.Set(v) - } - } -} diff --git a/callback_row_query.go b/callback_row_query.go deleted file mode 100644 index 323b1605..00000000 --- a/callback_row_query.go +++ /dev/null @@ -1,41 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" -) - -// Define callbacks for row query -func init() { - DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) -} - -type RowQueryResult struct { - Row *sql.Row -} - -type RowsQueryResult struct { - Rows *sql.Rows - Error error -} - -// queryCallback used to query data from database -func rowQueryCallback(scope *Scope) { - if result, ok := scope.InstanceGet("row_query_result"); ok { - scope.prepareQuerySQL() - - if str, ok := scope.Get("gorm:query_hint"); ok { - scope.SQL = fmt.Sprint(str) + scope.SQL - } - - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) - } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) - } - } -} diff --git a/callback_save.go b/callback_save.go deleted file mode 100644 index 3b4e0589..00000000 --- a/callback_save.go +++ /dev/null @@ -1,170 +0,0 @@ -package gorm - -import ( - "reflect" - "strings" -) - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { - checkTruth := func(value interface{}) bool { - if v, ok := value.(bool); ok && !v { - return false - } - - if v, ok := value.(string); ok { - v = strings.ToLower(v) - return v == "true" - } - - return true - } - - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if r = field.Relationship; r != nil { - autoUpdate, autoCreate, saveReference = true, true, true - - if value, ok := scope.Get("gorm:save_associations"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } - - if value, ok := scope.Get("gorm:association_autoupdate"); ok { - autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { - autoUpdate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_autocreate"); ok { - autoCreate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { - autoCreate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_save_reference"); ok { - saveReference = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { - saveReference = checkTruth(value) - } - } - } - - return -} - -func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - newScope := scope.New(fieldValue) - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if saveReference { - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(newDB.Save(elem).Error) - } - } else if autoUpdate { - scope.Err(newDB.Save(elem).Error) - } - - if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/callback_system_test.go b/callback_system_test.go deleted file mode 100644 index 2482eda4..00000000 --- a/callback_system_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package gorm - -import ( - "reflect" - "runtime" - "strings" - "testing" -) - -func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { - var names []string - for _, f := range funcs { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") - names = append(names, fnames[len(fnames)-1]) - } - return reflect.DeepEqual(names, fnames) -} - -func create(s *Scope) {} -func beforeCreate1(s *Scope) {} -func beforeCreate2(s *Scope) {} -func afterCreate1(s *Scope) {} -func afterCreate2(s *Scope) {} - -func TestRegisterCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} - - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("before_create2", beforeCreate2) - callback.Create().Register("create", create) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Register("after_create2", afterCreate2) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback") - } -} - -func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{logger: defaultLogger} - callback1.Create().Register("before_create1", beforeCreate1) - callback1.Create().Register("create", create) - callback1.Create().Register("after_create1", afterCreate1) - callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{logger: defaultLogger} - - callback2.Update().Register("create", create) - callback2.Update().Before("create").Register("before_create1", beforeCreate1) - callback2.Update().After("after_create2").Register("after_create1", afterCreate1) - callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) - callback2.Update().Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } -} - -func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{logger: defaultLogger} - - callback1.Query().Before("after_create1").After("before_create1").Register("create", create) - callback1.Query().Register("before_create1", beforeCreate1) - callback1.Query().Register("after_create1", afterCreate1) - - if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{logger: defaultLogger} - - callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - callback2.Delete().Register("after_create1", afterCreate1) - callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback with order") - } -} - -func replaceCreate(s *Scope) {} - -func TestReplaceCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Replace("create", replaceCreate) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { - t.Errorf("replace callback") - } -} - -func TestRemoveCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Remove("create") - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { - t.Errorf("remove callback") - } -} diff --git a/callback_update.go b/callback_update.go deleted file mode 100644 index 699e534b..00000000 --- a/callback_update.go +++ /dev/null @@ -1,121 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "sort" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause while updating")) - return - } - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", scope.db.nowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - // Sort the column names so that the generated SQL is the same every time. - updateMap := updateAttrs.(map[string]interface{}) - var columns []string - for c := range updateMap { - columns = append(columns, c) - } - sort.Strings(columns) - - for _, column := range columns { - value := updateMap[column] - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { - if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/callbacks_test.go b/callbacks_test.go deleted file mode 100644 index bebd0e38..00000000 --- a/callbacks_test.go +++ /dev/null @@ -1,249 +0,0 @@ -package gorm_test - -import ( - "errors" - "reflect" - "testing" - - "github.com/jinzhu/gorm" -) - -func (s *Product) BeforeCreate() (err error) { - if s.Code == "Invalid" { - err = errors.New("invalid product") - } - s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 - return -} - -func (s *Product) BeforeUpdate() (err error) { - if s.Code == "dont_update" { - err = errors.New("can't update") - } - s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 - return -} - -func (s *Product) BeforeSave() (err error) { - if s.Code == "dont_save" { - err = errors.New("can't save") - } - s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 - return -} - -func (s *Product) AfterFind() { - s.AfterFindCallTimes = s.AfterFindCallTimes + 1 -} - -func (s *Product) AfterCreate(tx *gorm.DB) { - tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) -} - -func (s *Product) AfterUpdate() { - s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 -} - -func (s *Product) AfterSave() (err error) { - if s.Code == "after_save_error" { - err = errors.New("can't save") - } - s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 - return -} - -func (s *Product) BeforeDelete() (err error) { - if s.Code == "dont_delete" { - err = errors.New("can't delete") - } - s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 - return -} - -func (s *Product) AfterDelete() (err error) { - if s.Code == "after_delete_error" { - err = errors.New("can't delete") - } - s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 - return -} - -func (s *Product) GetCallTimes() []int64 { - return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} -} - -func TestRunCallbacks(t *testing.T) { - p := Product{Code: "unique_code", Price: 100} - DB.Save(&p) - - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { - t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) - } - - p.Price = 200 - DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { - t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - var products []Product - DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should work with slice") - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { - t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) - } - - DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { - t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { - t.Errorf("Can't find a deleted record") - } -} - -func TestCallbacksWithErrors(t *testing.T) { - p := Product{Code: "Invalid", Price: 100} - if DB.Save(&p).Error == nil { - t.Errorf("An error from before create callbacks happened when create with invalid value") - } - - if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { - t.Errorf("Should not save record that have errors") - } - - if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { - t.Errorf("An error from after create callbacks happened when create with invalid value") - } - - p2 := Product{Code: "update_callback", Price: 100} - DB.Save(&p2) - - p2.Code = "dont_update" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before update callbacks happened when update with invalid value") - } - - if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - p2.Code = "dont_save" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before save callbacks happened when update with invalid value") - } - - p3 := Product{Code: "dont_delete", Price: 100} - DB.Save(&p3) - if DB.Delete(&p3).Error == nil { - t.Errorf("An error from before delete callbacks happened when delete") - } - - if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { - t.Errorf("An error from before delete callbacks happened") - } - - p4 := Product{Code: "after_save_error", Price: 100} - DB.Save(&p4) - if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { - t.Errorf("Record should be reverted if get an error in after save callback") - } - - p5 := Product{Code: "after_delete_error", Price: 100} - DB.Save(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record should be found") - } - - DB.Delete(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") - } -} - -func TestGetCallback(t *testing.T) { - scope := DB.NewScope(nil) - - if DB.Callback().Create().Get("gorm:test_callback") != nil { - t.Errorf("`gorm:test_callback` should be nil") - } - - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) - callback := DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { - t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) - } - - DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) - callback = DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { - t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) - } - - DB.Callback().Create().Remove("gorm:test_callback") - if DB.Callback().Create().Get("gorm:test_callback") != nil { - t.Errorf("`gorm:test_callback` should be nil") - } - - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) - callback = DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { - t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) - } -} - -func TestUseDefaultCallback(t *testing.T) { - createCallbackName := "gorm:test_use_default_callback_for_create" - gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { - // nop - }) - if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { - t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) - } - gorm.DefaultCallback.Create().Remove(createCallbackName) - if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { - t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) - } - - updateCallbackName := "gorm:test_use_default_callback_for_update" - scopeValueName := "gorm:test_use_default_callback_for_update_value" - gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { - scope.Set(scopeValueName, 1) - }) - gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { - scope.Set(scopeValueName, 2) - }) - - scope := DB.NewScope(nil) - callback := gorm.DefaultCallback.Update().Get(updateCallbackName) - callback(scope) - if v, ok := scope.Get(scopeValueName); !ok || v != 2 { - t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) - } -} diff --git a/create_test.go b/create_test.go deleted file mode 100644 index c80bdcbb..00000000 --- a/create_test.go +++ /dev/null @@ -1,288 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "testing" - "time" - - "github.com/jinzhu/now" -) - -func TestCreate(t *testing.T) { - float := 35.03554004971999 - now := time.Now() - user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} - - if !DB.NewRecord(user) || !DB.NewRecord(&user) { - t.Error("User should be new record before create") - } - - if count := DB.Save(&user).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - if DB.NewRecord(user) || DB.NewRecord(&user) { - t.Error("User should not new record after save") - } - - var newUser User - if err := DB.First(&newUser, user.Id).Error; err != nil { - t.Errorf("No error should happen, but got %v", err) - } - - if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { - t.Errorf("User's PasswordHash should be saved ([]byte)") - } - - if newUser.Age != 18 { - t.Errorf("User's Age should be saved (int)") - } - - if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) - } - - if newUser.Latitude != float { - t.Errorf("Float64 should not be changed after save") - } - - if user.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - if newUser.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - DB.Model(user).Update("name", "create_user_new_name") - DB.First(&user, user.Id) - if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { - t.Errorf("CreatedAt should not be changed after update") - } -} - -func TestCreateEmptyStrut(t *testing.T) { - type EmptyStruct struct { - ID uint - } - DB.AutoMigrate(&EmptyStruct{}) - - if err := DB.Create(&EmptyStruct{}).Error; err != nil { - t.Errorf("No error should happen when creating user, but got %v", err) - } -} - -func TestCreateWithExistingTimestamp(t *testing.T) { - user := User{Name: "CreateUserExistingTimestamp"} - - timeA := now.MustParse("2016-01-01") - user.CreatedAt = timeA - user.UpdatedAt = timeA - DB.Save(&user) - - if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt should not be changed") - } - - if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt should not be changed") - } - - var newUser User - DB.First(&newUser, user.Id) - - if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt should not be changed") - } - - if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt should not be changed") - } -} - -func TestCreateWithNowFuncOverride(t *testing.T) { - user1 := User{Name: "CreateUserTimestampOverride"} - - timeA := now.MustParse("2016-01-01") - - // do DB.New() because we don't want this test to affect other tests - db1 := DB.New() - // set the override to use static timeA - db1.SetNowFuncOverride(func() time.Time { - return timeA - }) - // call .New again to check the override is carried over as well during clone - db1 = db1.New() - - db1.Save(&user1) - - if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt be using the nowFuncOverride") - } - if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt be using the nowFuncOverride") - } - - // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set - // to make sure that setting it only affected the above instance - - user2 := User{Name: "CreateUserTimestampOverrideNoMore"} - - db2 := DB.New() - - db2.Save(&user2) - - if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt no longer be using the nowFuncOverride") - } - if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt no longer be using the nowFuncOverride") - } -} - -type AutoIncrementUser struct { - User - Sequence uint `gorm:"AUTO_INCREMENT"` -} - -func TestCreateWithAutoIncrement(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") - } - - DB.AutoMigrate(&AutoIncrementUser{}) - - user1 := AutoIncrementUser{} - user2 := AutoIncrementUser{} - - DB.Create(&user1) - DB.Create(&user2) - - if user2.Sequence-user1.Sequence != 1 { - t.Errorf("Auto increment should apply on Sequence") - } -} - -func TestCreateWithNoGORMPrimayKey(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { - t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") - } - - jt := JoinTable{From: 1, To: 2} - err := DB.Create(&jt).Error - if err != nil { - t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) - } -} - -func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - if DB.Save(&animal).Error != nil { - t.Errorf("No error should happen when create a record without std primary key") - } - - if animal.Counter == 0 { - t.Errorf("No std primary key should be filled value after create") - } - - if animal.Name != "Ferdinand" { - t.Errorf("Default value should be overrided") - } - - // Test create with default value not overrided - an := Animal{From: "nerdz"} - - if DB.Save(&an).Error != nil { - t.Errorf("No error should happen when create an record without std primary key") - } - - // We must fetch the value again, to have the default fields updated - // (We can't do this in the update statements, since sql default can be expressions - // And be different from the fields' type (eg. a time.Time fields has a default value of "now()" - DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) - - if an.Name != "galeone" { - t.Errorf("Default value should fill the field. But got %v", an.Name) - } -} - -func TestAnonymousScanner(t *testing.T) { - user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_scanner") - if user2.Role.Name != "admin" { - t.Errorf("Should be able to get anonymous scanner") - } - - if !user2.Role.IsAdmin() { - t.Errorf("Should be able to get anonymous scanner") - } -} - -func TestAnonymousField(t *testing.T) { - user := User{Name: "anonymous_field", Company: Company{Name: "company"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_field") - DB.Model(&user2).Related(&user2.Company) - if user2.Company.Name != "company" { - t.Errorf("Should be able to get anonymous field") - } -} - -func TestSelectWithCreate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_create") - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name != user.Name || queryuser.Age == user.Age { - t.Errorf("Should only create users with name column") - } - - if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || - queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { - t.Errorf("Should only create selected relationships") - } -} - -func TestOmitWithCreate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_create") - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name == user.Name || queryuser.Age != user.Age { - t.Errorf("Should only create users with age column") - } - - if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || - queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { - t.Errorf("Should not create omitted relationships") - } -} - -func TestCreateIgnore(t *testing.T) { - float := 35.03554004971999 - now := time.Now() - user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} - - if !DB.NewRecord(user) || !DB.NewRecord(&user) { - t.Error("User should be new record before create") - } - - if count := DB.Create(&user).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil { - t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ") - } -} diff --git a/customize_column_test.go b/customize_column_test.go deleted file mode 100644 index c236ac24..00000000 --- a/customize_column_test.go +++ /dev/null @@ -1,357 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date *time.Time `gorm:"column:mapped_time"` -} - -// Make sure an ignored field does not interfere with another field's custom -// column name that matches the ignored field. -type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` -} - -func TestCustomizeColumn(t *testing.T) { - col := "mapped_name" - DB.DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) - - scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { - t.Errorf("CustomizeColumn should have column %s", col) - } - - col = "mapped_id" - if scope.PrimaryKey() != col { - t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) - } - - expected := "foo" - now := time.Now() - cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} - - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - var cc1 CustomizeColumn - DB.First(&cc1, 666) - - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } - - cc.Name = "bar" - DB.Save(&cc) - - var cc2 CustomizeColumn - DB.First(&cc2, 666) - if cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } -} - -func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { - t.Errorf("Should not raise error: %s", err) - } -} - -type CustomizePerson struct { - IdPerson string `gorm:"column:idPerson;primary_key:true"` - Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` -} - -type CustomizeAccount struct { - IdAccount string `gorm:"column:idAccount;primary_key:true"` - Name string -} - -func TestManyToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") - DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) - - account := CustomizeAccount{IdAccount: "account", Name: "id1"} - person := CustomizePerson{ - IdPerson: "person", - Accounts: []CustomizeAccount{account}, - } - - if err := DB.Create(&account).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if err := DB.Create(&person).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - var person1 CustomizePerson - scope := DB.NewScope(nil) - if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { - t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) - } - - if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { - t.Errorf("should preload correct accounts") - } -} - -type CustomizeUser struct { - gorm.Model - Email string `sql:"column:email_address"` -} - -type CustomizeInvitation struct { - gorm.Model - Address string `sql:"column:invitation"` - Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` -} - -func TestOneToOneWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) - DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) - - user := CustomizeUser{ - Email: "hello@example.com", - } - invitation := CustomizeInvitation{ - Address: "hello@example.com", - } - - DB.Create(&user) - DB.Create(&invitation) - - var invitation2 CustomizeInvitation - if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if invitation2.Person.Email != user.Email { - t.Errorf("Should preload one to one relation with customize foreign keys") - } -} - -type PromotionDiscount struct { - gorm.Model - Name string - Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` - Rule *PromotionRule `gorm:"ForeignKey:discount_id"` - Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` -} - -type PromotionBenefit struct { - gorm.Model - Name string - PromotionID uint - Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` -} - -type PromotionCoupon struct { - gorm.Model - Code string - DiscountID uint - Discount PromotionDiscount -} - -type PromotionRule struct { - gorm.Model - Name string - Begin *time.Time - End *time.Time - DiscountID uint - Discount *PromotionDiscount -} - -func TestOneToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) - - discount := PromotionDiscount{ - Name: "Happy New Year", - Coupons: []*PromotionCoupon{ - {Code: "newyear1"}, - {Code: "newyear2"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Coupons) != 2 { - t.Errorf("should find two coupons") - } - - var coupon PromotionCoupon - if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if coupon.Discount.Name != "Happy New Year" { - t.Errorf("should preload discount from coupon") - } -} - -func TestHasOneWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) - - var begin = time.Now() - var end = time.Now().Add(24 * time.Hour) - discount := PromotionDiscount{ - Name: "Happy New Year 2", - Rule: &PromotionRule{ - Name: "time_limited", - Begin: &begin, - End: &end, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { - t.Errorf("Should be able to preload Rule") - } - - var rule PromotionRule - if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if rule.Discount.Name != "Happy New Year 2" { - t.Errorf("should preload discount from rule") - } -} - -func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) - - discount := PromotionDiscount{ - Name: "Happy New Year 3", - Benefits: []PromotionBenefit{ - {Name: "free cod"}, - {Name: "free shipping"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Benefits) != 2 { - t.Errorf("should find two benefits") - } - - var benefit PromotionBenefit - if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if benefit.Discount.Name != "Happy New Year 3" { - t.Errorf("should preload discount from coupon") - } -} - -type SelfReferencingUser struct { - gorm.Model - Name string - Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` -} - -func TestSelfReferencingMany2ManyColumn(t *testing.T) { - DB.DropTable(&SelfReferencingUser{}, "UserFriends") - DB.AutoMigrate(&SelfReferencingUser{}) - if !DB.HasTable("UserFriends") { - t.Errorf("auto migrate error, table UserFriends should be created") - } - - friend1 := SelfReferencingUser{Name: "friend1_m2m"} - if err := DB.Create(&friend1).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - friend2 := SelfReferencingUser{Name: "friend2_m2m"} - if err := DB.Create(&friend2).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - user := SelfReferencingUser{ - Name: "self_m2m", - Friends: []*SelfReferencingUser{&friend1, &friend2}, - } - - if err := DB.Create(&user).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if DB.Model(&user).Association("Friends").Count() != 2 { - t.Errorf("Should find created friends correctly") - } - - var count int - if err := DB.Table("UserFriends").Count(&count).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - if count == 0 { - t.Errorf("table UserFriends should have records") - } - - var newUser = SelfReferencingUser{} - - if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if len(newUser.Friends) != 2 { - t.Errorf("Should preload created frineds for self reference m2m") - } - - DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) - if DB.Model(&user).Association("Friends").Count() != 3 { - t.Errorf("Should find created friends correctly") - } - - DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) - if DB.Model(&user).Association("Friends").Count() != 1 { - t.Errorf("Should find created friends correctly") - } - - friend := SelfReferencingUser{} - DB.Model(&newUser).Association("Friends").Find(&friend) - if friend.Name != "friend4_m2m" { - t.Errorf("Should find created friends correctly") - } - - DB.Model(&newUser).Association("Friends").Delete(friend) - if DB.Model(&user).Association("Friends").Count() != 0 { - t.Errorf("All friends should be deleted") - } -} diff --git a/delete_test.go b/delete_test.go deleted file mode 100644 index 043641f7..00000000 --- a/delete_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" -) - -func TestDelete(t *testing.T) { - user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if err := DB.Delete(&user1).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } - - if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("Other users that not deleted should be found-able") - } -} - -func TestInlineDelete(t *testing.T) { - user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if DB.Delete(&User{}, user1.Id).Error != nil { - t.Errorf("No error should happen when delete a record") - } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } -} - -func TestSoftDelete(t *testing.T) { - type User struct { - Id int64 - Name string - DeletedAt *time.Time - } - DB.AutoMigrate(&User{}) - - user := User{Name: "soft_delete"} - DB.Save(&user) - DB.Delete(&user) - - if DB.First(&User{}, "name = ?", user.Name).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&user) - if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} - -func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) { - creditCard := CreditCard{Number: "411111111234567"} - DB.Save(&creditCard) - DB.Delete(&creditCard) - - if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" { - t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`") - } - - if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&creditCard) - if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} diff --git a/dialect.go b/dialect.go deleted file mode 100644 index 749587f4..00000000 --- a/dialect.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db SQLCommon) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - // ModifyColumn modify column's type - ModifyColumn(tableName string, columnName string, typ string) error - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) (string, error) - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` - LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string - // DefaultValueStr - DefaultValueStr() string - - // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference - BuildKeyName(kind, tableName string, fields ...string) string - - // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect - NormalizeIndexAndColumn(indexName, columnName string) (string, string) - - // CurrentDatabase return current database name - CurrentDatabase() string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db SQLCommon) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// GetDialect gets the dialect for the specified dialect name -func GetDialect(name string) (dialect Dialect, ok bool) { - dialect, ok = dialectsMap[name] - return -} - -// ParseFieldStructForDialect get field's sql data type -var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var ( - reflectType = field.Struct.Type - dataType, _ = field.TagSettingsGet("TYPE") - ) - - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - if gormDataType, ok := fieldValue.Interface().(interface { - GormDataType(Dialect) string - }); ok { - dataType = gormDataType.GormDataType(dialect) - } - - // Get scanner's real value - if dataType == "" { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - } - - // Default Size - if num, ok := field.TagSettingsGet("SIZE"); ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - notNull, _ := field.TagSettingsGet("NOT NULL") - unique, _ := field.TagSettingsGet("UNIQUE") - additionalType = notNull + " " + unique - if value, ok := field.TagSettingsGet("DEFAULT"); ok { - additionalType = additionalType + " DEFAULT " + value - } - - if value, ok := field.TagSettingsGet("COMMENT"); ok { - additionalType = additionalType + " COMMENT " + value - } - - return fieldValue, dataType, size, strings.TrimSpace(additionalType) -} - -func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} diff --git a/dialect_common.go b/dialect_common.go deleted file mode 100644 index d549510c..00000000 --- a/dialect_common.go +++ /dev/null @@ -1,196 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") - -// DefaultForeignKeyNamer contains the default foreign key name generator method -type DefaultForeignKeyNamer struct { -} - -type commonDialect struct { - db SQLCommon - DefaultForeignKeyNamer -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db SQLCommon) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$$" // ? -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - return strings.ToLower(value) != "false" - } - return field.IsPrimaryKey -} - -func (s *commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) - return err -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -// LimitAndOffsetSQL return generated SQL with Limit and Offset -func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - if parsedLimit, err := s.parseInt(limit); err != nil { - return "", err - } else if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := s.parseInt(offset); err != nil { - return "", err - } else if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (commonDialect) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference -func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) - keyName = keyNameRegex.ReplaceAllString(keyName, "_") - return keyName -} - -// NormalizeIndexAndColumn returns argument's index name and column name without doing anything -func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - return indexName, columnName -} - -func (commonDialect) parseInt(value interface{}) (int64, error) { - return strconv.ParseInt(fmt.Sprint(value), 0, 0) -} - -// IsByteArrayOrSlice returns true of the reflected value is an array or slice -func IsByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} diff --git a/dialect_mysql.go b/dialect_mysql.go deleted file mode 100644 index b4467ffa..00000000 --- a/dialect_mysql.go +++ /dev/null @@ -1,246 +0,0 @@ -package gorm - -import ( - "crypto/sha1" - "database/sql" - "fmt" - "reflect" - "regexp" - "strings" - "time" - "unicode/utf8" -) - -var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (s *mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - // MySQL allows only one auto increment column per table, and it must - // be a KEY column. - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { - field.TagSettingsDelete("AUTO_INCREMENT") - } - } - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint AUTO_INCREMENT" - } else { - sqlType = "tinyint" - } - case reflect.Int, reflect.Int16, reflect.Int32: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint unsigned AUTO_INCREMENT" - } else { - sqlType = "tinyint unsigned" - } - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - precision := "" - if p, ok := field.TagSettingsGet("PRECISION"); ok { - precision = fmt.Sprintf("(%s)", p) - } - - if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { - sqlType = fmt.Sprintf("DATETIME%v", precision) - } else { - sqlType = fmt.Sprintf("DATETIME%v NULL", precision) - } - } - default: - if IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - parsedLimit, err := s.parseInt(limit) - if err != nil { - return "", err - } - if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - - if offset != nil { - parsedOffset, err := s.parseInt(offset) - if err != nil { - return "", err - } - if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - } - } - return -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) HasTable(tableName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - var name string - // allow mysql database name with '-' character - if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { - if err == sql.ErrNoRows { - return false - } - panic(err) - } else { - return true - } -} - -func (s mysql) HasIndex(tableName string, indexName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) HasColumn(tableName string, columnName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) - if utf8.RuneCountInString(keyName) <= 64 { - return keyName - } - h := sha1.New() - h.Write([]byte(keyName)) - bs := h.Sum(nil) - - // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) - if len(destRunes) > 24 { - destRunes = destRunes[:24] - } - - return fmt.Sprintf("%s%x", string(destRunes), bs) -} - -// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed -func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - submatch := mysqlIndexRegex.FindStringSubmatch(indexName) - if len(submatch) != 3 { - return indexName, columnName - } - indexName = submatch[1] - columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) - return indexName, columnName -} - -func (mysql) DefaultValueStr() string { - return "VALUES()" -} diff --git a/dialect_postgres.go b/dialect_postgres.go deleted file mode 100644 index d2df3131..00000000 --- a/dialect_postgres.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "encoding/json" - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) - RegisterDialect("cloudsqlpostgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (s *postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint32, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettingsGet("SIZE"); !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "bytea" - - if isUUID(dataValue) { - sqlType = "uuid" - } - - if isJSON(dataValue) { - sqlType = "jsonb" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { - return "" -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} - -func isJSON(value reflect.Value) bool { - _, ok := value.Interface().(json.RawMessage) - return ok -} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go deleted file mode 100644 index 5f96c363..00000000 --- a/dialect_sqlite3.go +++ /dev/null @@ -1,107 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func (s *sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "blob" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go deleted file mode 100644 index a516ed4a..00000000 --- a/dialects/mssql/mssql.go +++ /dev/null @@ -1,253 +0,0 @@ -package mssql - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "fmt" - "reflect" - "strconv" - "strings" - "time" - - // Importing mssql driver package only in dialect file, otherwide not needed - _ "github.com/denisenkom/go-mssqldb" - "github.com/jinzhu/gorm" -) - -func setIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) - scope.InstanceSet("mssql:identity_insert_on", true) - } - } - } -} - -func turnOffIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) - } - } -} - -func init() { - gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) - gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) - gorm.RegisterDialect("mssql", &mssql{}) -} - -type mssql struct { - db gorm.SQLCommon - gorm.DefaultForeignKeyNamer -} - -func (mssql) GetName() string { - return "mssql" -} - -func (s *mssql) SetDB(db gorm.SQLCommon) { - s.db = db -} - -func (mssql) BindVar(i int) string { - return "$$$" // ? -} - -func (mssql) Quote(key string) string { - return fmt.Sprintf(`[%s]`, key) -} - -func (s *mssql) DataTypeOf(field *gorm.StructField) string { - var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bit" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int IDENTITY(1,1)" - } else { - sqlType = "int" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint IDENTITY(1,1)" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "float" - case reflect.String: - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("nvarchar(%d)", size) - } else { - sqlType = "nvarchar(max)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetimeoffset" - } - default: - if gorm.IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "varbinary(max)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - return value != "FALSE" - } - return field.IsPrimaryKey -} - -func (s mssql) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) - return count > 0 -} - -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow(`SELECT count(*) - FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id - inner join information_schema.tables as I on I.TABLE_NAME = T.name - WHERE F.name = ? - AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) - return count > 0 -} - -func (s mssql) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) - return count > 0 -} - -func (s mssql) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) - return -} - -func parseInt(value interface{}) (int64, error) { - return strconv.ParseInt(fmt.Sprint(value), 0, 0) -} - -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if offset != nil { - if parsedOffset, err := parseInt(offset); err != nil { - return "", err - } else if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) - } - } - if limit != nil { - if parsedLimit, err := parseInt(limit); err != nil { - return "", err - } else if parsedLimit >= 0 { - if sql == "" { - // add default zero offset - sql += " OFFSET 0 ROWS" - } - sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) - } - } - return -} - -func (mssql) SelectFromDummyTable() string { - return "" -} - -func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - if len(columns) == 0 { - // No OUTPUT to query - return "" - } - return fmt.Sprintf("OUTPUT Inserted.%v", columnName) -} - -func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id - return "; SELECT SCOPE_IDENTITY()" -} - -func (mssql) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -// NormalizeIndexAndColumn returns argument's index name and column name without doing anything -func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - return indexName, columnName -} - -func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} - -// JSON type to support easy handling of JSON data in character table fields -// using golang json.RawMessage for deferred decoding/encoding -type JSON struct { - json.RawMessage -} - -// Value get value of JSON -func (j JSON) Value() (driver.Value, error) { - if len(j.RawMessage) == 0 { - return nil, nil - } - return j.MarshalJSON() -} - -// Scan scan value into JSON -func (j *JSON) Scan(value interface{}) error { - str, ok := value.(string) - if !ok { - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) - } - bytes := []byte(str) - return json.Unmarshal(bytes, j) -} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go deleted file mode 100644 index 9deba48a..00000000 --- a/dialects/mysql/mysql.go +++ /dev/null @@ -1,3 +0,0 @@ -package mysql - -import _ "github.com/go-sql-driver/mysql" diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go deleted file mode 100644 index e6c088b1..00000000 --- a/dialects/postgres/postgres.go +++ /dev/null @@ -1,81 +0,0 @@ -package postgres - -import ( - "database/sql" - "database/sql/driver" - - "encoding/json" - "errors" - "fmt" - - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" -) - -type Hstore map[string]*string - -// Value get value of Hstore -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -// Scan scan value into Hstore -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} - -// Jsonb Postgresql's JSONB data type -type Jsonb struct { - json.RawMessage -} - -// Value get value of Jsonb -func (j Jsonb) Value() (driver.Value, error) { - if len(j.RawMessage) == 0 { - return nil, nil - } - return j.MarshalJSON() -} - -// Scan scan value into Jsonb -func (j *Jsonb) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) - } - - return json.Unmarshal(bytes, j) -} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go deleted file mode 100644 index 069ad3a9..00000000 --- a/dialects/sqlite/sqlite.go +++ /dev/null @@ -1,3 +0,0 @@ -package sqlite - -import _ "github.com/mattn/go-sqlite3" diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 79bf5fc3..00000000 --- a/docker-compose.yml +++ /dev/null @@ -1,30 +0,0 @@ -version: '3' - -services: - mysql: - image: 'mysql:latest' - ports: - - 9910:3306 - environment: - - MYSQL_DATABASE=gorm - - MYSQL_USER=gorm - - MYSQL_PASSWORD=gorm - - MYSQL_RANDOM_ROOT_PASSWORD="yes" - postgres: - image: 'postgres:latest' - ports: - - 9920:5432 - environment: - - POSTGRES_USER=gorm - - POSTGRES_DB=gorm - - POSTGRES_PASSWORD=gorm - mssql: - image: 'mcmoe/mssqldocker:latest' - ports: - - 9930:1433 - environment: - - ACCEPT_EULA=Y - - SA_PASSWORD=LoremIpsum86 - - MSSQL_DB=gorm - - MSSQL_USER=gorm - - MSSQL_PASSWORD=LoremIpsum86 diff --git a/embedded_struct_test.go b/embedded_struct_test.go deleted file mode 100644 index 5f8ece57..00000000 --- a/embedded_struct_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package gorm_test - -import "testing" - -type BasePost struct { - Id int64 - Title string - URL string -} - -type Author struct { - ID string - Name string - Email string -} - -type HNPost struct { - BasePost - Author `gorm:"embedded_prefix:user_"` // Embedded struct - Upvotes int32 -} - -type EngadgetPost struct { - BasePost BasePost `gorm:"embedded"` - Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct - ImageUrl string -} - -func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { - dialect := DB.NewScope(&EngadgetPost{}).Dialect() - engadgetPostScope := DB.NewScope(&EngadgetPost{}) - if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { - t.Errorf("should has prefix for embedded columns") - } - - if len(engadgetPostScope.PrimaryFields()) != 1 { - t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) - } - - hnScope := DB.NewScope(&HNPost{}) - if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { - t.Errorf("should has prefix for embedded columns") - } -} - -func TestSaveAndQueryEmbeddedStruct(t *testing.T) { - DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) - DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) - var news HNPost - if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if news.Title != "hn_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) - var egNews EngadgetPost - if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if egNews.BasePost.Title != "engadget_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - if DB.NewScope(&HNPost{}).PrimaryField() == nil { - t.Errorf("primary key with embedded struct should works") - } - - for _, field := range DB.NewScope(&HNPost{}).Fields() { - if field.Name == "BasePost" { - t.Errorf("scope Fields should not contain embedded struct") - } - } -} - -func TestEmbeddedPointerTypeStruct(t *testing.T) { - type HNPost struct { - *BasePost - Upvotes int32 - } - - DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) - - var hnPost HNPost - if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { - t.Errorf("No error should happen when find embedded pointer type, but got %v", err) - } - - if hnPost.Title != "embedded_pointer_type" { - t.Errorf("Should find correct value for embedded pointer type") - } -} diff --git a/errors.go b/errors.go deleted file mode 100644 index d5ef8d57..00000000 --- a/errors.go +++ /dev/null @@ -1,72 +0,0 @@ -package gorm - -import ( - "errors" - "strings" -) - -var ( - // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL occurs when you attempt a query with invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") -) - -// Errors contains all happened errors -type Errors []error - -// IsRecordNotFoundError returns true if error contains a RecordNotFound error -func IsRecordNotFoundError(err error) bool { - if errs, ok := err.(Errors); ok { - for _, err := range errs { - if err == ErrRecordNotFound { - return true - } - } - } - return err == ErrRecordNotFound -} - -// GetErrors gets all errors that have occurred and returns a slice of errors (Error type) -func (errs Errors) GetErrors() []error { - return errs -} - -// Add adds an error to a given slice of errors -func (errs Errors) Add(newErrors ...error) Errors { - for _, err := range newErrors { - if err == nil { - continue - } - - if errors, ok := err.(Errors); ok { - errs = errs.Add(errors...) - } else { - ok = true - for _, e := range errs { - if err == e { - ok = false - } - } - if ok { - errs = append(errs, err) - } - } - } - return errs -} - -// Error takes a slice of all errors that have occurred and returns it as a formatted string -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/errors_test.go b/errors_test.go deleted file mode 100644 index 9a428dec..00000000 --- a/errors_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package gorm_test - -import ( - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestErrorsCanBeUsedOutsideGorm(t *testing.T) { - errs := []error{errors.New("First"), errors.New("Second")} - - gErrs := gorm.Errors(errs) - gErrs = gErrs.Add(errors.New("Third")) - gErrs = gErrs.Add(gErrs) - - if gErrs.Error() != "First; Second; Third" { - t.Fatalf("Gave wrong error, got %s", gErrs.Error()) - } -} diff --git a/field.go b/field.go deleted file mode 100644 index acd06e20..00000000 --- a/field.go +++ /dev/null @@ -1,66 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - v := reflectValue.Interface() - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = scanner.Scan(v) - } - } else { - err = scanner.Scan(v) - } - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/field_test.go b/field_test.go deleted file mode 100644 index 715661f0..00000000 --- a/field_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package gorm_test - -import ( - "database/sql/driver" - "encoding/hex" - "fmt" - "testing" - - "github.com/jinzhu/gorm" -) - -type CalculateField struct { - gorm.Model - Name string - Children []CalculateFieldChild - Category CalculateFieldCategory - EmbeddedField -} - -type EmbeddedField struct { - EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` -} - -type CalculateFieldChild struct { - gorm.Model - CalculateFieldID uint - Name string -} - -type CalculateFieldCategory struct { - gorm.Model - CalculateFieldID uint - Name string -} - -func TestCalculateField(t *testing.T) { - var field CalculateField - var scope = DB.NewScope(&field) - if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("embedded_name"); !ok { - t.Errorf("should find embedded field") - } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok { - t.Errorf("should find embedded field's tag settings") - } -} - -type UUID [16]byte - -type NullUUID struct { - UUID - Valid bool -} - -func FromString(input string) (u UUID) { - src := []byte(input) - return FromBytes(src) -} - -func FromBytes(src []byte) (u UUID) { - dst := u[:] - hex.Decode(dst[0:4], src[0:8]) - hex.Decode(dst[4:6], src[9:13]) - hex.Decode(dst[6:8], src[14:18]) - hex.Decode(dst[8:10], src[19:23]) - hex.Decode(dst[10:], src[24:]) - return -} - -func (u UUID) String() string { - buf := make([]byte, 36) - src := u[:] - hex.Encode(buf[0:8], src[0:4]) - buf[8] = '-' - hex.Encode(buf[9:13], src[4:6]) - buf[13] = '-' - hex.Encode(buf[14:18], src[6:8]) - buf[18] = '-' - hex.Encode(buf[19:23], src[8:10]) - buf[23] = '-' - hex.Encode(buf[24:], src[10:]) - return string(buf) -} - -func (u UUID) Value() (driver.Value, error) { - return u.String(), nil -} - -func (u *UUID) Scan(src interface{}) error { - switch src := src.(type) { - case UUID: // support gorm convert from UUID to NullUUID - *u = src - return nil - case []byte: - *u = FromBytes(src) - return nil - case string: - *u = FromString(src) - return nil - } - return fmt.Errorf("uuid: cannot convert %T to UUID", src) -} - -func (u *NullUUID) Scan(src interface{}) error { - u.Valid = true - return u.UUID.Scan(src) -} - -func TestFieldSet(t *testing.T) { - type TestFieldSetNullUUID struct { - NullUUID NullUUID - } - scope := DB.NewScope(&TestFieldSetNullUUID{}) - field := scope.Fields()[0] - err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00")) - if err != nil { - t.Fatal(err) - } - if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok { - t.Fatal() - } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" { - t.Fatal(id) - } -} diff --git a/go.mod b/go.mod index 91ff3cb8..0b3e3065 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1 @@ module github.com/jinzhu/gorm - -go 1.12 - -require ( - github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd - github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 - github.com/go-sql-driver/mysql v1.5.0 - github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.0.1 - github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v2.0.1+incompatible - golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect -) diff --git a/go.sum b/go.sum deleted file mode 100644 index e09a0352..00000000 --- a/go.sum +++ /dev/null @@ -1,25 +0,0 @@ -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= -github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= -github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/interface.go b/interface.go deleted file mode 100644 index fe649231..00000000 --- a/interface.go +++ /dev/null @@ -1,24 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" -) - -// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. -type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/join_table_handler.go b/join_table_handler.go deleted file mode 100644 index a036d46d..00000000 --- a/join_table_handler.go +++ /dev/null @@ -1,211 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys []JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - s.Source.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - s.Destination.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return DefaultTableNameHandler(db, s.TableName) -} - -func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - for _, joinTableSource := range joinTableSources { - if joinTableSource.ModelType == modelType { - for _, foreignKey := range joinTableSource.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - conditionMap[foreignKey.DBName] = field.Field.Interface() - } - } - break - } - } - } -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - var ( - scope = db.NewScope("") - conditionMap = map[string]interface{}{} - ) - - // Update condition map for source - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) - - // Update condition map for destination - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range conditionMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - conditionMap = map[string]interface{}{} - ) - - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) - - for key, value := range conditionMap { - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). - Where(condString, toQueryValues(foreignFieldValues)...) - } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/join_table_test.go b/join_table_test.go deleted file mode 100644 index 6d5f427d..00000000 --- a/join_table_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package gorm_test - -import ( - "fmt" - "strconv" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type Person struct { - Id int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` -} - -type PersonAddress struct { - gorm.JoinTableHandler - PersonID int - AddressID int - DeletedAt *time.Time - CreatedAt time.Time -} - -func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { - foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue())) - associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue())) - if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{ - "person_id": foreignPrimaryKey, - "address_id": associationPrimaryKey, - }).Update(map[string]interface{}{ - "person_id": foreignPrimaryKey, - "address_id": associationPrimaryKey, - "deleted_at": gorm.Expr("NULL"), - }).RowsAffected; result == 0 { - return db.Create(&PersonAddress{ - PersonID: foreignPrimaryKey, - AddressID: associationPrimaryKey, - }).Error - } - - return nil -} - -func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { - return db.Delete(&PersonAddress{}).Error -} - -func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { - table := pa.Table(db) - return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) -} - -func TestJoinTable(t *testing.T) { - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &Person{Name: "person", Addresses: []*Address{address1, address2}} - DB.Save(person) - - DB.Model(person).Association("Addresses").Delete(address1) - - if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { - t.Errorf("Should found one address") - } - - if DB.Model(person).Association("Addresses").Count() != 1 { - t.Errorf("Should found one address") - } - - if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { - t.Errorf("Found two addresses with Unscoped") - } - - if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} - -func TestEmbeddedMany2ManyRelationship(t *testing.T) { - type EmbeddedPerson struct { - ID int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` - } - - type NewPerson struct { - EmbeddedPerson - ExternalID uint - } - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&NewPerson{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}} - if err := DB.Save(person).Error; err != nil { - t.Errorf("no error should return when save embedded many2many relationship, but got %v", err) - } - - if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil { - t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err) - } - - association := DB.Model(person).Association("Addresses") - if count := association.Count(); count != 1 || association.Error != nil { - t.Errorf("Should found one address, but got %v, error is %v", count, association.Error) - } - - if association.Clear(); association.Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 88e167dd..00000000 --- a/logger.go +++ /dev/null @@ -1,141 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "strconv" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`\?`) - numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) -) - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} - -var LogFormatter = func(values ...interface{}) (messages []interface{}) { - if len(values) > 1 { - var ( - sql string - formattedValues []string - level = values[0] - currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - ) - - messages = []interface{}{source, currentTime} - - if len(values) == 2 { - //remove the line break - currentTime = currentTime[1:] - //remove the brackets - source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) - - messages = []interface{}{currentTime, source} - } - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := value.(time.Time); ok { - if t.IsZero() { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) - } - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - switch value.(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: - formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) - default: - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } - } else { - formattedValues = append(formattedValues, "NULL") - } - } - - // differentiate between $n placeholders or else treat like ? - if numericPlaceHolderRegexp.MatchString(values[3].(string)) { - sql = values[3].(string) - for index, value := range formattedValues { - placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) - sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") - } - } else { - formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - } - - messages = append(messages, sql) - messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) - messages = append(messages, "\033[0m") - } - } - - return -} - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - logger.Println(LogFormatter(values...)...) -} - -type nopLogger struct{} - -func (nopLogger) Print(values ...interface{}) {} diff --git a/main.go b/main.go deleted file mode 100644 index 3db87870..00000000 --- a/main.go +++ /dev/null @@ -1,881 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "sync" - "time" -) - -// DB contains information for current db connection -type DB struct { - sync.RWMutex - Value interface{} - Error error - RowsAffected int64 - - // single db - db SQLCommon - blockGlobalUpdate bool - logMode logModeValue - logger logger - search *search - values sync.Map - - // global db - parent *DB - callbacks *Callback - dialect Dialect - singularTable bool - - // function to be used to override the creating of a new timestamp - nowFuncOverride func() time.Time -} - -type logModeValue int - -const ( - defaultLogMode logModeValue = iota - noLogMode - detailedLogMode -) - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (db *DB, err error) { - if len(args) == 0 { - err = errors.New("invalid database source") - return nil, err - } - var source string - var dbSQL SQLCommon - var ownDbSQL bool - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - ownDbSQL = true - case SQLCommon: - dbSQL = value - ownDbSQL = false - default: - return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) - } - - db = &DB{ - db: dbSQL, - logger: defaultLogger, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), - } - db.parent = db - if err != nil { - return - } - // Send a ping to make sure the database connection is alive. - if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil && ownDbSQL { - d.Close() - } - } - return -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -type closer interface { - Close() error -} - -// Close close current db connection. If database connection is not an io.Closer, returns an error. -func (s *DB) Close() error { - if db, ok := s.parent.db.(closer); ok { - return db.Close() - } - return errors.New("can't close current db") -} - -// DB get `*sql.DB` from current connection -// If the underlying database connection is not a *sql.DB, returns nil -func (s *DB) DB() *sql.DB { - db, ok := s.db.(*sql.DB) - if !ok { - panic("can't support full GORM on currently status, maybe this is a TX instance.") - } - return db -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() SQLCommon { - return s.db -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.dialect -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone(s.logger) - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = detailedLogMode - } else { - s.logMode = noLogMode - } - return s -} - -// SetNowFuncOverride set the function to be used when creating a new timestamp -func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { - s.nowFuncOverride = nowFuncOverride - return s -} - -// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, -// otherwise defaults to the global NowFunc() -func (s *DB) nowFunc() time.Time { - if s.nowFuncOverride != nil { - return s.nowFuncOverride() - } - - return NowFunc() -} - -// BlockGlobalUpdate if true, generates an error on update/delete without where clause. -// This is to prevent eventual error with empty objects updates/deletions -func (s *DB) BlockGlobalUpdate(enable bool) *DB { - s.blockGlobalUpdate = enable - return s -} - -// HasBlockGlobalUpdate return state of block -func (s *DB) HasBlockGlobalUpdate() bool { - return s.blockGlobalUpdate -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - s.parent.Lock() - defer s.parent.Unlock() - s.parent.singularTable = enable -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - scope := &Scope{db: dbClone, Value: value} - if s.search != nil { - scope.Search = s.search.clone() - } else { - scope.Search = &search{} - } - return scope -} - -// QueryExpr returns the query as SqlExpr object -func (s *DB) QueryExpr() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(scope.SQL, scope.SQLVars...) -} - -// SubQuery returns the query as sub query -func (s *DB) SubQuery() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query interface{}, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Take return a record that match given conditions, the order will depend on the database implementation -func (s *DB) Take(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -//Preloads preloads relations, don`t touch out -func (s *DB) Preloads(out interface{}) *DB { - return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - scope = s.NewScope(result) - clone = scope.db - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := s.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db - } else if len(c.search.assignAttrs) > 0 { - return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -// WARNING when update with struct, GORM will not update fields that with zero value -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.NewScope(value) - if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db - if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().Table(scope.TableName()).FirstOrCreate(value) - } - return newDB - } - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.NewScope(nil) - generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Transaction start a transaction as a block, -// return error will rollback, otherwise to commit. -func (s *DB) Transaction(fc func(tx *DB) error) (err error) { - panicked := true - tx := s.Begin() - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - tx.Rollback() - } - }() - - err = fc(tx) - - if err == nil { - err = tx.Commit().Error - } - - panicked = false - return -} - -// Begin begins a transaction -func (s *DB) Begin() *DB { - return s.BeginTx(context.Background(), &sql.TxOptions{}) -} - -// BeginTx begins a transaction with options -func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.BeginTx(ctx, opts) - c.db = interface{}(tx).(SQLCommon) - - c.dialect.SetDB(c.db) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - if err := db.Rollback(); err != nil && err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// RollbackUnlessCommitted rollback a transaction if it has not yet been -// committed. -func (s *DB) RollbackUnlessCommitted() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - err := db.Rollback() - // Ignore the error indicating that the transaction has already - // been committed. - if err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) - return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// RemoveForeignKey Remove foreign key from the given scope, e.g: -// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") -func (s *DB) RemoveForeignKey(field string, dest string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeForeignKey(field, dest) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values.Store(name, value) - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values.Load(name) - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == defaultLogMode { - go s.print("error", fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors(s.GetErrors()) - errors = errors.Add(err) - if len(errors) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() []error { - if errs, ok := s.Error.(Errors); ok { - return errs - } else if s.Error != nil { - return []error{s.Error} - } - return []error{} -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := &DB{ - db: s.db, - parent: s.parent, - logger: s.logger, - logMode: s.logMode, - Value: s.Value, - Error: s.Error, - blockGlobalUpdate: s.blockGlobalUpdate, - dialect: newDialect(s.dialect.GetName(), s.db), - nowFuncOverride: s.nowFuncOverride, - } - - s.values.Range(func(k, v interface{}) bool { - db.values.Store(k, v) - return true - }) - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = db - return db -} - -func (s *DB) print(v ...interface{}) { - s.logger.Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == detailedLogMode { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == detailedLogMode { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) - } -} diff --git a/main_test.go b/main_test.go deleted file mode 100644 index b51fe413..00000000 --- a/main_test.go +++ /dev/null @@ -1,1444 +0,0 @@ -package gorm_test - -// Run tests -// $ docker-compose up -// $ ./test_all.sh - -import ( - "context" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "os" - "path/filepath" - "reflect" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/erikstmartin/go-testdb" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/now" -) - -var ( - DB *gorm.DB - t1, t2, t3, t4, t5 time.Time -) - -func init() { - var err error - - if DB, err = OpenTestConnection(); err != nil { - panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) - } - - runMigration() -} - -func OpenTestConnection() (db *gorm.DB, err error) { - dbDSN := os.Getenv("GORM_DSN") - switch os.Getenv("GORM_DIALECT") { - case "mysql": - fmt.Println("testing mysql...") - if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" - } - db, err = gorm.Open("mysql", dbDSN) - case "postgres": - fmt.Println("testing postgres...") - if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" - } - db, err = gorm.Open("postgres", dbDSN) - case "mssql": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; - // CREATE DATABASE gorm; - // USE gorm; - // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - fmt.Println("testing mssql...") - if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" - } - db, err = gorm.Open("mssql", dbDSN) - default: - fmt.Println("testing sqlite3...") - db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) - } - - // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) - // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if debug := os.Getenv("DEBUG"); debug == "true" { - db.LogMode(true) - } else if debug == "false" { - db.LogMode(false) - } - - db.DB().SetMaxIdleConns(10) - - return -} - -func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { - stringRef := "foo" - testCases := []interface{}{42, time.Now(), &stringRef} - for _, tc := range testCases { - t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { - _, err := gorm.Open("postgresql", tc) - if err == nil { - t.Error("Should got error with invalid database source") - } - if !strings.HasPrefix(err.Error(), "invalid database source:") { - t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) - } - }) - } -} - -func TestStringPrimaryKey(t *testing.T) { - type UUIDStruct struct { - ID string `gorm:"primary_key"` - Name string - } - DB.DropTable(&UUIDStruct{}) - DB.AutoMigrate(&UUIDStruct{}) - - data := UUIDStruct{ID: "uuid", Name: "hello"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" { - t.Errorf("string primary key should not be populated") - } - - data = UUIDStruct{ID: "uuid", Name: "hello world"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" { - t.Errorf("string primary key should not be populated") - } -} - -func TestExceptionsWithInvalidSql(t *testing.T) { - var columns []string - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - var count1, count2 int64 - DB.Model(&User{}).Count(&count1) - if count1 <= 0 { - t.Errorf("Should find some users") - } - - if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - DB.Model(&User{}).Count(&count2) - if count1 != count2 { - t.Errorf("No user should not be deleted by invalid SQL") - } -} - -func TestSetTable(t *testing.T) { - DB.Create(getPreparedUser("pluck_user1", "pluck_user")) - DB.Create(getPreparedUser("pluck_user2", "pluck_user")) - DB.Create(getPreparedUser("pluck_user3", "pluck_user")) - - if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { - t.Error("No errors should happen if set table for pluck", err) - } - - var users []User - if DB.Table("users").Find(&[]User{}).Error != nil { - t.Errorf("No errors should happen if set table for find") - } - - if DB.Table("invalid_table").Find(&users).Error == nil { - t.Errorf("Should got error when table is set to an invalid table") - } - - DB.Exec("drop table deleted_users;") - if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { - t.Errorf("Create table with specified table") - } - - DB.Table("deleted_users").Save(&User{Name: "DeletedUser"}) - - var deletedUsers []User - DB.Table("deleted_users").Find(&deletedUsers) - if len(deletedUsers) != 1 { - t.Errorf("Query from specified table") - } - - var user User - DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser") - - user.Age = 20 - DB.Table("deleted_users").Save(&user) - if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() { - t.Errorf("Failed to found updated user") - } - - DB.Save(getPreparedUser("normal_user", "reset_table")) - DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) - var user1, user2, user3 User - DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) - if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { - t.Errorf("unset specified table with blank string") - } -} - -type Order struct { -} - -type Cart struct { -} - -func (c Cart) TableName() string { - return "shopping_cart" -} - -func TestHasTable(t *testing.T) { - type Foo struct { - Id int - Stuff string - } - DB.DropTable(&Foo{}) - - // Table should not exist at this point, HasTable should return false - if ok := DB.HasTable("foos"); ok { - t.Errorf("Table should not exist, but does") - } - if ok := DB.HasTable(&Foo{}); ok { - t.Errorf("Table should not exist, but does") - } - - // We create the table - if err := DB.CreateTable(&Foo{}).Error; err != nil { - t.Errorf("Table should be created") - } - - // And now it should exits, and HasTable should return true - if ok := DB.HasTable("foos"); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } - if ok := DB.HasTable(&Foo{}); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } -} - -func TestTableName(t *testing.T) { - DB := DB.Model("") - if DB.NewScope(Order{}).TableName() != "orders" { - t.Errorf("Order's table name should be orders") - } - - if DB.NewScope(&Order{}).TableName() != "orders" { - t.Errorf("&Order's table name should be orders") - } - - if DB.NewScope([]Order{}).TableName() != "orders" { - t.Errorf("[]Order's table name should be orders") - } - - if DB.NewScope(&[]Order{}).TableName() != "orders" { - t.Errorf("&[]Order's table name should be orders") - } - - DB.SingularTable(true) - if DB.NewScope(Order{}).TableName() != "order" { - t.Errorf("Order's singular table name should be order") - } - - if DB.NewScope(&Order{}).TableName() != "order" { - t.Errorf("&Order's singular table name should be order") - } - - if DB.NewScope([]Order{}).TableName() != "order" { - t.Errorf("[]Order's singular table name should be order") - } - - if DB.NewScope(&[]Order{}).TableName() != "order" { - t.Errorf("&[]Order's singular table name should be order") - } - - if DB.NewScope(&Cart{}).TableName() != "shopping_cart" { - t.Errorf("&Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(Cart{}).TableName() != "shopping_cart" { - t.Errorf("Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" { - t.Errorf("&[]Cart's singular table name should be shopping_cart") - } - - if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { - t.Errorf("[]Cart's singular table name should be shopping_cart") - } - DB.SingularTable(false) -} - -func TestTableNameConcurrently(t *testing.T) { - DB := DB.Model("") - if DB.NewScope(Order{}).TableName() != "orders" { - t.Errorf("Order's table name should be orders") - } - - var wg sync.WaitGroup - wg.Add(10) - - for i := 1; i <= 10; i++ { - go func(db *gorm.DB) { - DB.SingularTable(true) - wg.Done() - }(DB) - } - wg.Wait() - - if DB.NewScope(Order{}).TableName() != "order" { - t.Errorf("Order's singular table name should be order") - } - - DB.SingularTable(false) -} - -func TestNullValues(t *testing.T) { - DB.DropTable(&NullValue{}) - DB.AutoMigrate(&NullValue{}) - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: true}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv NullValue - DB.First(&nv, "name = ?", "hello") - - if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-2", Valid: true}, - Gender: &sql.NullString{String: "F", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv2 NullValue - DB.First(&nv2, "name = ?", "hello-2") - if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-3", Valid: false}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err == nil { - t.Errorf("Can't save because of name can't be null") - } -} - -func TestNullValuesWithFirstOrCreate(t *testing.T) { - var nv1 = NullValue{ - Name: sql.NullString{String: "first_or_create", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - } - - var nv2 NullValue - result := DB.Where(nv1).FirstOrCreate(&nv2) - - if result.RowsAffected != 1 { - t.Errorf("RowsAffected should be 1 after create some record") - } - - if result.Error != nil { - t.Errorf("Should not raise any error, but got %v", result.Error) - } - - if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { - t.Errorf("first or create with nullvalues") - } - - if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { - t.Errorf("Should not raise any error, but got %v", err) - } - - if nv2.Age.Int64 != 18 { - t.Errorf("should update age to 18") - } -} - -func TestTransaction(t *testing.T) { - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") - } - - tx.Rollback() - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - tx2 := DB.Begin() - u2 := User{Name: "transcation-2"} - if err := tx2.Save(&u2).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx2.Commit() - - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } - - tx3 := DB.Begin() - u3 := User{Name: "transcation-3"} - if err := tx3.Save(&u3).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx3.RollbackUnlessCommitted() - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - tx4 := DB.Begin() - u4 := User{Name: "transcation-4"} - if err := tx4.Save(&u4).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx4.Commit() - - tx4.RollbackUnlessCommitted() - - if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil { - t.Errorf("Should be able to find committed record") - } -} - -func assertPanic(t *testing.T, f func()) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - f() -} - -func TestTransactionWithBlock(t *testing.T) { - // rollback - err := DB.Transaction(func(tx *gorm.DB) error { - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - return errors.New("the error message") - }) - - if err.Error() != "the error message" { - t.Errorf("Transaction return error will equal the block returns error") - } - - if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - // commit - DB.Transaction(func(tx *gorm.DB) error { - u2 := User{Name: "transcation-2"} - if err := tx.Save(&u2).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record") - } - return nil - }) - - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } - - // panic will rollback - assertPanic(t, func() { - DB.Transaction(func(tx *gorm.DB) error { - u3 := User{Name: "transcation-3"} - if err := tx.Save(&u3).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { - t.Errorf("Should find saved record") - } - - panic("force panic") - }) - }) - - if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after panic rollback") - } -} - -func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.Commit().Error; err != nil { - t.Errorf("Commit should not raise error") - } - - if err := tx.Rollback().Error; err != nil { - t.Errorf("Rollback should not raise error") - } -} - -func TestTransactionReadonly(t *testing.T) { - dialect := os.Getenv("GORM_DIALECT") - if dialect == "" { - dialect = "sqlite" - } - switch dialect { - case "mssql", "sqlite": - t.Skipf("%s does not support readonly transactions\n", dialect) - } - - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - tx.Commit() - - tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") - } - - u = User{Name: "transcation-2"} - if err := tx.Save(&u).Error; err == nil { - t.Errorf("Error should have been raised in a readonly transaction") - } - - tx.Rollback() -} - -func TestRow(t *testing.T) { - user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() - var age int64 - row.Scan(&age) - if age != 10 { - t.Errorf("Scan with Row") - } -} - -func TestRows(t *testing.T) { - user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - count := 0 - for rows.Next() { - var name string - var age int64 - rows.Scan(&name, &age) - count++ - } - - if count != 2 { - t.Errorf("Should found two records") - } -} - -func TestScanRows(t *testing.T) { - user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - type Result struct { - Name string - Age int - } - - var results []Result - for rows.Next() { - var result Result - if err := DB.ScanRows(rows, &result); err != nil { - t.Errorf("should get no error, but got %v", err) - } - results = append(results, result) - } - - if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") - } -} - -func TestScan(t *testing.T) { - user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Age int - } - - var res result - DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res) - if res.Name != user3.Name { - t.Errorf("Scan into struct should work") - } - - var doubleAgeRes = &result{} - if err := DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes).Error; err != nil { - t.Errorf("Scan to pointer of pointer") - } - if doubleAgeRes.Age != res.Age*2 { - t.Errorf("Scan double age as age") - } - - var ress []result - DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Scan into struct map") - } -} - -func TestRaw(t *testing.T) { - user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Email string - } - - var ress []result - DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Raw with scan") - } - - rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() - count := 0 - for rows.Next() { - count++ - } - if count != 1 { - t.Errorf("Raw with Rows should find one record with name 3") - } - - DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { - t.Error("Raw sql to update records") - } -} - -func TestGroup(t *testing.T) { - rows, err := DB.Select("name").Table("users").Group("name").Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - rows.Scan(&name) - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestJoins(t *testing.T) { - var user = User{ - Name: "joins", - CreditCard: CreditCard{Number: "411111111111"}, - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var users1 []User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) - if len(users1) != 2 { - t.Errorf("should find two users using left join") - } - - var users2 []User - DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) - if len(users2) != 1 { - t.Errorf("should find one users using left join with conditions") - } - - var users3 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) - if len(users3) != 1 { - t.Errorf("should find one users using multiple left join conditions") - } - - var users4 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) - if len(users4) != 0 { - t.Errorf("should find no user when searching with unexisting credit card") - } - - var users5 []User - db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5) - if db5.Error != nil { - t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) - } -} - -type JoinedIds struct { - UserID int64 `gorm:"column:id"` - BillingAddressID int64 `gorm:"column:id"` - EmailID int64 `gorm:"column:id"` -} - -func TestScanIdenticalColumnNames(t *testing.T) { - var user = User{ - Name: "joinsIds", - Email: "joinIds@example.com", - BillingAddress: Address{ - Address1: "One Park Place", - }, - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var users []JoinedIds - DB.Select("users.id, addresses.id, emails.id").Table("users"). - Joins("left join addresses on users.billing_address_id = addresses.id"). - Joins("left join emails on emails.user_id = users.id"). - Where("name = ?", "joinsIds").Scan(&users) - - if len(users) != 2 { - t.Fatal("should find two rows using left join") - } - - if user.Id != users[0].UserID { - t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) - } - if user.Id != users[1].UserID { - t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) - } - - if user.BillingAddressID.Int64 != users[0].BillingAddressID { - t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) - } - if user.BillingAddressID.Int64 != users[1].BillingAddressID { - t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) - } - - if users[0].EmailID == users[1].EmailID { - t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) - } - - if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { - t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) - } - - if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { - t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) - } -} - -func TestJoinsWithSelect(t *testing.T) { - type result struct { - Name string - Email string - } - - user := User{ - Name: "joins_with_select", - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var results []result - DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) - if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { - t.Errorf("Should find all two emails with Join select") - } -} - -func TestHaving(t *testing.T) { - rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - var total int64 - rows.Scan(&name, &total) - - if name == "2" && total != 1 { - t.Errorf("Should have one user having name 2") - } - if name == "3" && total != 2 { - t.Errorf("Should have two users having name 3") - } - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestQueryBuilderSubselectInWhere(t *testing.T) { - user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128} - DB.Save(&user) - - var users []User - DB.Select("*").Where("name IN (?)", DB. - Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) - - if len(users) != 4 { - t.Errorf("Four users should be found, instead found %d", len(users)) - } - - DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB. - Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) - - if len(users) != 2 { - t.Errorf("Two users should be found, instead found %d", len(users)) - } -} - -func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { - user := User{Name: "subquery_test_user1", Age: 10} - DB.Save(&user) - user = User{Name: "subquery_test_user2", Age: 11} - DB.Save(&user) - user = User{Name: "subquery_test_user3", Age: 12} - DB.Save(&user) - - var count int - err := DB.Raw("select count(*) from (?) tmp", - DB.Table("users"). - Select("name"). - Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}). - Group("name"). - QueryExpr(), - ).Count(&count).Error - - if err != nil { - t.Errorf("Expected to get no errors, but got %v", err) - } - if count != 2 { - t.Errorf("Row count must be 2, instead got %d", count) - } - - err = DB.Raw("select count(*) from (?) tmp", - DB.Table("users"). - Select("name"). - Where("name LIKE ?", "subquery_test%"). - Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}). - Group("name"). - QueryExpr(), - ).Count(&count).Error - - if err != nil { - t.Errorf("Expected to get no errors, but got %v", err) - } - if count != 1 { - t.Errorf("Row count must be 1, instead got %d", count) - } -} - -func TestQueryBuilderSubselectInHaving(t *testing.T) { - user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128} - DB.Save(&user) - - var users []User - DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB. - Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users) - - if len(users) != 1 { - t.Errorf("Two user group should be found, instead found %d", len(users)) - } -} - -func DialectHasTzSupport() bool { - // NB: mssql and FoundationDB do not support time zones. - if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" { - return false - } - return true -} - -func TestTimeWithZone(t *testing.T) { - var format = "2006-01-02 15:04:05 -0700" - var times []time.Time - GMT8, _ := time.LoadLocation("Asia/Shanghai") - times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) - times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) - - for index, vtime := range times { - name := "time_with_zone_" + strconv.Itoa(index) - user := User{Name: name, Birthday: &vtime} - - if !DialectHasTzSupport() { - // If our driver dialect doesn't support TZ's, just use UTC for everything here. - utcBirthday := user.Birthday.UTC() - user.Birthday = &utcBirthday - } - - DB.Save(&user) - expectedBirthday := "2013-02-18 17:51:49 +0000" - foundBirthday := user.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - var findUser, findUser2, findUser3 User - DB.First(&findUser, "name = ?", name) - foundBirthday = findUser.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { - t.Errorf("User should be found") - } - - if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { - t.Errorf("User should not be found") - } - } -} - -func TestHstore(t *testing.T) { - type Details struct { - Id int64 - Bulk postgres.Hstore - } - - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip() - } - - if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { - fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") - panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) - } - - DB.Exec("drop table details") - - if err := DB.CreateTable(&Details{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - - bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" - bulk := map[string]*string{ - "bankAccountId": &bankAccountId, - "phoneNumber": &phoneNumber, - "opinion": &opinion, - } - d := Details{Bulk: bulk} - DB.Save(&d) - - var d2 Details - if err := DB.First(&d2).Error; err != nil { - t.Errorf("Got error when tried to fetch details: %+v", err) - } - - for k := range bulk { - if r, ok := d2.Bulk[k]; ok { - if res, _ := bulk[k]; *res != *r { - t.Errorf("Details should be equal") - } - } else { - t.Errorf("Details should be existed") - } - } -} - -func TestSetAndGet(t *testing.T) { - if value, ok := DB.Set("hello", "world").Get("hello"); !ok { - t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } - } - - if _, ok := DB.Get("non_existing"); ok { - t.Errorf("Get non existing key should return error") - } -} - -func TestCompatibilityMode(t *testing.T) { - DB, _ := gorm.Open("testdb", "") - testdb.SetQueryFunc(func(query string) (driver.Rows, error) { - columns := []string{"id", "name", "age"} - result := ` - 1,Tim,20 - 2,Joe,25 - 3,Bob,30 - ` - return testdb.RowsFromCSVString(columns, result), nil - }) - - var users []User - DB.Find(&users) - if (users[0].Name != "Tim") || len(users) != 3 { - t.Errorf("Unexcepted result returned") - } -} - -func TestOpenExistingDB(t *testing.T) { - DB.Save(&User{Name: "jnfeinstein"}) - dialect := os.Getenv("GORM_DIALECT") - - db, err := gorm.Open(dialect, DB.DB()) - if err != nil { - t.Errorf("Should have wrapped the existing DB connection") - } - - var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { - t.Errorf("Should have found existing record") - } -} - -func TestDdlErrors(t *testing.T) { - var err error - - if err = DB.Close(); err != nil { - t.Errorf("Closing DDL test db connection err=%s", err) - } - defer func() { - // Reopen DB connection. - if DB, err = OpenTestConnection(); err != nil { - t.Fatalf("Failed re-opening db connection: %s", err) - } - }() - - if err := DB.Find(&User{}).Error; err == nil { - t.Errorf("Expected operation on closed db to produce an error, but err was nil") - } -} - -func TestOpenWithOneParameter(t *testing.T) { - db, err := gorm.Open("dialect") - if db != nil { - t.Error("Open with one parameter returned non nil for db") - } - if err == nil { - t.Error("Open with one parameter returned err as nil") - } -} - -func TestSaveAssociations(t *testing.T) { - db := DB.New() - deltaAddressCount := 0 - if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { - t.Errorf("failed to fetch address count") - t.FailNow() - } - - placeAddress := &Address{ - Address1: "somewhere on earth", - } - ownerAddress1 := &Address{ - Address1: "near place address", - } - ownerAddress2 := &Address{ - Address1: "address2", - } - db.Create(placeAddress) - - addressCountShouldBe := func(t *testing.T, expectedCount int) { - countFromDB := 0 - t.Helper() - err := db.Model(&Address{}).Count(&countFromDB).Error - if err != nil { - t.Error("failed to fetch address count") - } - if countFromDB != expectedCount { - t.Errorf("address count mismatch: %d", countFromDB) - } - } - addressCountShouldBe(t, deltaAddressCount+1) - - // owner address should be created, place address should be reused - place1 := &Place{ - PlaceAddressID: placeAddress.ID, - PlaceAddress: placeAddress, - OwnerAddress: ownerAddress1, - } - err := db.Create(place1).Error - if err != nil { - t.Errorf("failed to store place: %s", err.Error()) - } - addressCountShouldBe(t, deltaAddressCount+2) - - // owner address should be created again, place address should be reused - place2 := &Place{ - PlaceAddressID: placeAddress.ID, - PlaceAddress: &Address{ - ID: 777, - Address1: "address1", - }, - OwnerAddress: ownerAddress2, - OwnerAddressID: 778, - } - err = db.Create(place2).Error - if err != nil { - t.Errorf("failed to store place: %s", err.Error()) - } - addressCountShouldBe(t, deltaAddressCount+3) - - count := 0 - db.Model(&Place{}).Where(&Place{ - PlaceAddressID: placeAddress.ID, - OwnerAddressID: ownerAddress1.ID, - }).Count(&count) - if count != 1 { - t.Errorf("only one instance of (%d, %d) should be available, found: %d", - placeAddress.ID, ownerAddress1.ID, count) - } - - db.Model(&Place{}).Where(&Place{ - PlaceAddressID: placeAddress.ID, - OwnerAddressID: ownerAddress2.ID, - }).Count(&count) - if count != 1 { - t.Errorf("only one instance of (%d, %d) should be available, found: %d", - placeAddress.ID, ownerAddress2.ID, count) - } - - db.Model(&Place{}).Where(&Place{ - PlaceAddressID: placeAddress.ID, - }).Count(&count) - if count != 2 { - t.Errorf("two instances of (%d) should be available, found: %d", - placeAddress.ID, count) - } -} - -func TestBlockGlobalUpdate(t *testing.T) { - db := DB.New() - db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) - - err := db.Model(&Toy{}).Update("OwnerType", "Human").Error - if err != nil { - t.Error("Unexpected error on global update") - } - - err = db.Delete(&Toy{}).Error - if err != nil { - t.Error("Unexpected error on global delete") - } - - db.BlockGlobalUpdate(true) - - db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) - - err = db.Model(&Toy{}).Update("OwnerType", "Human").Error - if err == nil { - t.Error("Expected error on global update") - } - - err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error - if err != nil { - t.Error("Unxpected error on conditional update") - } - - err = db.Delete(&Toy{}).Error - if err == nil { - t.Error("Expected error on global delete") - } - err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error - if err != nil { - t.Error("Unexpected error on conditional delete") - } -} - -func TestCountWithHaving(t *testing.T) { - db := DB.New() - db.Delete(User{}) - defer db.Delete(User{}) - - DB.Create(getPreparedUser("user1", "pluck_user")) - DB.Create(getPreparedUser("user2", "pluck_user")) - user3 := getPreparedUser("user3", "pluck_user") - user3.Languages = []Language{} - DB.Create(user3) - - var count int - err := db.Model(User{}).Select("users.id"). - Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id"). - Joins("LEFT JOIN languages ON user_languages.language_id = languages.id"). - Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error - - if err != nil { - t.Error("Unexpected error on query count with having") - } - - if count != 2 { - t.Error("Unexpected result on query count with having") - } -} - -func TestPluck(t *testing.T) { - db := DB.New() - db.Delete(User{}) - defer db.Delete(User{}) - - DB.Create(&User{Id: 1, Name: "user1"}) - DB.Create(&User{Id: 2, Name: "user2"}) - DB.Create(&User{Id: 3, Name: "user3"}) - - var ids []int64 - err := db.Model(User{}).Order("id").Pluck("id", &ids).Error - - if err != nil { - t.Error("Unexpected error on pluck") - } - - if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { - t.Error("Unexpected result on pluck") - } - - err = db.Model(User{}).Order("id").Pluck("id", &ids).Error - - if err != nil { - t.Error("Unexpected error on pluck again") - } - - if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { - t.Error("Unexpected result on pluck again") - } -} - -func TestCountWithQueryOption(t *testing.T) { - db := DB.New() - db.Delete(User{}) - defer db.Delete(User{}) - - DB.Create(&User{Name: "user1"}) - DB.Create(&User{Name: "user2"}) - DB.Create(&User{Name: "user3"}) - - var count int - err := db.Model(User{}).Select("users.id"). - Set("gorm:query_option", "WHERE users.name='user2'"). - Count(&count).Error - - if err != nil { - t.Error("Unexpected error on query count with query_option") - } - - if count != 1 { - t.Error("Unexpected result on query count with query_option") - } -} - -func TestQueryHint1(t *testing.T) { - db := DB.New() - - _, err := db.Model(User{}).Raw("select 1").Rows() - - if err != nil { - t.Error("Unexpected error on query count with query_option") - } -} - -func TestQueryHint2(t *testing.T) { - type TestStruct struct { - ID string `gorm:"primary_key"` - Name string - } - DB.DropTable(&TestStruct{}) - DB.AutoMigrate(&TestStruct{}) - - data := TestStruct{ID: "uuid", Name: "hello"} - if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil { - t.Error("Unexpected error on query count with query_option") - } -} - -func TestFloatColumnPrecision(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" { - t.Skip() - } - - type FloatTest struct { - ID string `gorm:"primary_key"` - FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"` - } - DB.DropTable(&FloatTest{}) - DB.AutoMigrate(&FloatTest{}) - - data := FloatTest{ID: "uuid", FloatValue: 112.57315} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.FloatValue != 112.57315 { - t.Errorf("Float value should not lose precision") - } -} - -func TestWhereUpdates(t *testing.T) { - type OwnerEntity struct { - gorm.Model - OwnerID uint - OwnerType string - } - - type SomeEntity struct { - gorm.Model - Name string - OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"` - } - - DB.DropTable(&SomeEntity{}) - DB.AutoMigrate(&SomeEntity{}) - - a := SomeEntity{Name: "test"} - DB.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) -} - -func BenchmarkGorm(b *testing.B) { - b.N = 2000 - for x := 0; x < b.N; x++ { - e := strconv.Itoa(x) + "benchmark@example.org" - now := time.Now() - email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} - // Insert - DB.Save(&email) - // Query - DB.First(&EmailWithIdx{}, "email = ?", e) - // Update - DB.Model(&email).UpdateColumn("email", "new-"+e) - // Delete - DB.Delete(&email) - } -} - -func BenchmarkRawSql(b *testing.B) { - DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") - DB.SetMaxIdleConns(10) - insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" - querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" - updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" - deleteSql := "DELETE FROM orders WHERE id = $1" - - b.N = 2000 - for x := 0; x < b.N; x++ { - var id int64 - e := strconv.Itoa(x) + "benchmark@example.org" - now := time.Now() - email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} - // Insert - DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) - // Query - rows, _ := DB.Query(querySql, email.Email) - rows.Close() - // Update - DB.Exec(updateSql, "new-"+e, time.Now(), id) - // Delete - DB.Exec(deleteSql, id) - } -} - -func parseTime(str string) *time.Time { - t := now.New(time.Now().UTC()).MustParse(str) - return &t -} diff --git a/migration_test.go b/migration_test.go deleted file mode 100644 index d94ec9ec..00000000 --- a/migration_test.go +++ /dev/null @@ -1,579 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "os" - "reflect" - "strconv" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type User struct { - Id int64 - Age int64 - UserNum Num - Name string `sql:"size:255"` - Email string - Birthday *time.Time // Time - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressID sql.NullInt64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key - CreditCard CreditCard - Latitude float64 - Languages []Language `gorm:"many2many:user_languages;"` - CompanyID *int - Company Company - Role Role - Password EncryptedData - PasswordHash []byte - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` - Ignored struct{ Name string } `sql:"-"` - IgnoredPointer *User `sql:"-"` -} - -type NotSoLongTableName struct { - Id int64 - ReallyLongThingID int64 - ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit -} - -type ReallyLongTableNameToTestMySQLNameLengthLimit struct { - Id int64 -} - -type ReallyLongThingThatReferencesShort struct { - Id int64 - ShortID int64 - Short Short -} - -type Short struct { - Id int64 -} - -type CreditCard struct { - ID int8 - Number string - UserId sql.NullInt64 - CreatedAt time.Time `sql:"not null"` - UpdatedAt time.Time - DeletedAt *time.Time `sql:"column:deleted_time"` -} - -type Email struct { - Id int16 - UserId int - Email string `sql:"type:varchar(100);"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Address struct { - ID int - Address1 string - Address2 string - Post string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Language struct { - gorm.Model - Name string - Users []User `gorm:"many2many:user_languages;"` -} - -type Product struct { - Id int64 - Code string - Price int64 - CreatedAt time.Time - UpdatedAt time.Time - AfterFindCallTimes int64 - BeforeCreateCallTimes int64 - AfterCreateCallTimes int64 - BeforeUpdateCallTimes int64 - AfterUpdateCallTimes int64 - BeforeSaveCallTimes int64 - AfterSaveCallTimes int64 - BeforeDeleteCallTimes int64 - AfterDeleteCallTimes int64 -} - -type Company struct { - Id int64 - Name string - Owner *User `sql:"-"` -} - -type Place struct { - Id int64 - PlaceAddressID int - PlaceAddress *Address `gorm:"save_associations:false"` - OwnerAddressID int - OwnerAddress *Address `gorm:"save_associations:true"` -} - -type EncryptedData []byte - -func (data *EncryptedData) Scan(value interface{}) error { - if b, ok := value.([]byte); ok { - if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { - return errors.New("Too short") - } - - *data = b[3:] - return nil - } - - return errors.New("Bytes expected") -} - -func (data EncryptedData) Value() (driver.Value, error) { - if len(data) > 0 && data[0] == 'x' { - //needed to test failures - return nil, errors.New("Should not start with 'x'") - } - - //prepend asterisks - return append([]byte("***"), data...), nil -} - -type Role struct { - Name string `gorm:"size:256"` -} - -func (role *Role) Scan(value interface{}) error { - if b, ok := value.([]uint8); ok { - role.Name = string(b) - } else { - role.Name = value.(string) - } - return nil -} - -func (role Role) Value() (driver.Value, error) { - return role.Name, nil -} - -func (role Role) IsAdmin() bool { - return role.Name == "admin" -} - -type Num int64 - -func (i *Num) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - n, _ := strconv.Atoi(string(s)) - *i = Num(n) - case int64: - *i = Num(s) - default: - return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) - } - return nil -} - -type Animal struct { - Counter uint64 `gorm:"primary_key:yes"` - Name string `sql:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `sql:"DEFAULT:current_timestamp"` - unexported string // unexported value - CreatedAt time.Time - UpdatedAt time.Time -} - -type JoinTable struct { - From uint64 - To uint64 - Time time.Time `sql:"default: null"` -} - -type Post struct { - Id int64 - CategoryId sql.NullInt64 - MainCategoryId int64 - Title string - Body string - Comments []*Comment - Category Category - MainCategory Category -} - -type Category struct { - gorm.Model - Name string - - Categories []Category - CategoryID *uint -} - -type Comment struct { - gorm.Model - PostId int64 - Content string - Post Post -} - -// Scanner -type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Gender *sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime -} - -type NullTime struct { - Time time.Time - Valid bool -} - -func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil -} - -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} - -func getPreparedUser(name string, role string) *User { - var company Company - DB.Where(Company{Name: role}).FirstOrCreate(&company) - - return &User{ - Name: name, - Age: 20, - Role: Role{role}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } -} - -func runMigration() { - if err := DB.DropTableIfExists(&User{}).Error; err != nil { - fmt.Printf("Got error when try to delete table users, %+v\n", err) - } - - for _, table := range []string{"animals", "user_languages"} { - DB.Exec(fmt.Sprintf("drop table %v;", table)) - } - - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} - for _, value := range values { - DB.DropTable(value) - } - if err := DB.AutoMigrate(values...).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } -} - -func TestIndexes(t *testing.T) { - if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email should have index idx_email_email") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email's index idx_email_email should be deleted") - } - - if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { - t.Errorf("Should get to create duplicate record when having unique index") - } - - var user = User{Name: "sample_user"} - DB.Save(&user) - if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { - t.Errorf("Should get no error when append two emails for user") - } - - if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { - t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { - t.Errorf("Should be able to create duplicated emails after remove unique index") - } -} - -type EmailWithIdx struct { - Id int64 - UserId int64 - Email string `sql:"index:idx_email_agent"` - UserAgent string `sql:"index:idx_email_agent"` - RegisteredAt *time.Time `sql:"unique_index"` - CreatedAt time.Time - UpdatedAt time.Time -} - -func TestAutoMigration(t *testing.T) { - DB.AutoMigrate(&Address{}) - DB.DropTable(&EmailWithIdx{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - now := time.Now() - DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) - - scope := DB.NewScope(&EmailWithIdx{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { - t.Errorf("Failed to create index") - } - - var bigemail EmailWithIdx - DB.First(&bigemail, "user_agent = ?", "pc") - if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { - t.Error("Big Emails should be saved and fetched correctly") - } -} - -func TestCreateAndAutomigrateTransaction(t *testing.T) { - tx := DB.Begin() - - func() { - type Bar struct { - ID uint - } - DB.DropTableIfExists(&Bar{}) - - if ok := DB.HasTable("bars"); ok { - t.Errorf("Table should not exist, but does") - } - - if ok := tx.HasTable("bars"); ok { - t.Errorf("Table should not exist, but does") - } - }() - - func() { - type Bar struct { - Name string - } - err := tx.CreateTable(&Bar{}).Error - - if err != nil { - t.Errorf("Should have been able to create the table, but couldn't: %s", err) - } - - if ok := tx.HasTable(&Bar{}); !ok { - t.Errorf("The transaction should be able to see the table") - } - }() - - func() { - type Bar struct { - Stuff string - } - - err := tx.AutoMigrate(&Bar{}).Error - if err != nil { - t.Errorf("Should have been able to alter the table, but couldn't") - } - }() - - tx.Rollback() -} - -type MultipleIndexes struct { - ID int64 - UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` - Name string `sql:"unique_index:uix_multipleindexes_user_name"` - Email string `sql:"unique_index:,uix_multipleindexes_user_email"` - Other string `sql:"index:,idx_multipleindexes_user_other"` -} - -func TestMultipleIndexes(t *testing.T) { - if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { - fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) - } - - DB.AutoMigrate(&MultipleIndexes{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) - - scope := DB.NewScope(&MultipleIndexes{}) - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { - t.Errorf("Failed to create index") - } - - var mutipleIndexes MultipleIndexes - DB.First(&mutipleIndexes, "name = ?", "jinzhu") - if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { - t.Error("MutipleIndexes should be saved and fetched correctly") - } - - // Check unique constraints - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } -} - -func TestModifyColumnType(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { - t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") - } - - type ModifyColumnType struct { - gorm.Model - Name1 string `gorm:"length:100"` - Name2 string `gorm:"length:200"` - } - DB.DropTable(&ModifyColumnType{}) - DB.CreateTable(&ModifyColumnType{}) - - name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") - name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) - - if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { - t.Errorf("No error should happen when ModifyColumn, but got %v", err) - } -} - -func TestIndexWithPrefixLength(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { - t.Skip("Skipping this because only mysql support setting an index prefix length") - } - - type IndexWithPrefix struct { - gorm.Model - Name string - Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - type IndexesWithPrefix struct { - gorm.Model - Name string - Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - type IndexesWithPrefixAndWithoutPrefix struct { - gorm.Model - Name string `gorm:"index:idx_index_with_prefixes_length"` - Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} - for _, table := range tables { - scope := DB.NewScope(table) - tableName := scope.TableName() - t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { - if err := DB.DropTableIfExists(table).Error; err != nil { - t.Errorf("Failed to drop %s table: %v", tableName, err) - } - if err := DB.CreateTable(table).Error; err != nil { - t.Errorf("Failed to create %s table: %v", tableName, err) - } - if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { - t.Errorf("Failed to create %s table index:", tableName) - } - }) - } -} diff --git a/model.go b/model.go deleted file mode 100644 index f37ff7ea..00000000 --- a/model.go +++ /dev/null @@ -1,14 +0,0 @@ -package gorm - -import "time" - -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` -} diff --git a/model_struct.go b/model_struct.go deleted file mode 100644 index d9e2e90f..00000000 --- a/model_struct.go +++ /dev/null @@ -1,671 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "go/ast" - "reflect" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" -) - -// DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { - return defaultTableName -} - -// lock for mutating global cached model metadata -var structsLock sync.Mutex - -// global cache of model metadata -var modelStructsMap sync.Map - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - - defaultTableName string - l sync.Mutex -} - -// TableName returns model's table name -func (s *ModelStruct) TableName(db *DB) string { - s.l.Lock() - defer s.l.Unlock() - - if s.defaultTableName == "" && db != nil && s.ModelType != nil { - // Set default table name - if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { - s.defaultTableName = tabler.TableName() - } else { - tableName := ToTableName(s.ModelType.Name()) - db.parent.RLock() - if db == nil || (db.parent != nil && !db.parent.singularTable) { - tableName = inflection.Plural(tableName) - } - db.parent.RUnlock() - s.defaultTableName = tableName - } - } - - return DefaultTableNameHandler(db, s.defaultTableName) -} - -// StructField model field's struct definition -type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsNormal bool - IsIgnored bool - IsScanner bool - HasDefaultValue bool - Tag reflect.StructTag - TagSettings map[string]string - Struct reflect.StructField - IsForeignKey bool - Relationship *Relationship - - tagSettingsLock sync.RWMutex -} - -// TagSettingsSet Sets a tag in the tag settings map -func (sf *StructField) TagSettingsSet(key, val string) { - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - sf.TagSettings[key] = val -} - -// TagSettingsGet returns a tag from the tag settings -func (sf *StructField) TagSettingsGet(key string) (string, bool) { - sf.tagSettingsLock.RLock() - defer sf.tagSettingsLock.RUnlock() - val, ok := sf.TagSettings[key] - return val, ok -} - -// TagSettingsDelete deletes a tag -func (sf *StructField) TagSettingsDelete(key string) { - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - delete(sf.TagSettings, key) -} - -func (sf *StructField) clone() *StructField { - clone := &StructField{ - DBName: sf.DBName, - Name: sf.Name, - Names: sf.Names, - IsPrimaryKey: sf.IsPrimaryKey, - IsNormal: sf.IsNormal, - IsIgnored: sf.IsIgnored, - IsScanner: sf.IsScanner, - HasDefaultValue: sf.HasDefaultValue, - Tag: sf.Tag, - TagSettings: map[string]string{}, - Struct: sf.Struct, - IsForeignKey: sf.IsForeignKey, - } - - if sf.Relationship != nil { - relationship := *sf.Relationship - clone.Relationship = &relationship - } - - // copy the struct field tagSettings, they should be read-locked while they are copied - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - for key, value := range sf.TagSettings { - clone.TagSettings[key] = value - } - - return clone -} - -// Relationship described the relationship between models -type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - PolymorphicValue string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface -} - -func getForeignField(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { - return field - } - } - return nil -} - -// GetModelStruct get value's model struct, relationships based on struct and tag definition -func (scope *Scope) GetModelStruct() *ModelStruct { - var modelStruct ModelStruct - // Scope value can't be nil - if scope.Value == nil { - return &modelStruct - } - - reflectType := reflect.ValueOf(scope.Value).Type() - for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Scope value need to be a struct - if reflectType.Kind() != reflect.Struct { - return &modelStruct - } - - // Get Cached model struct - isSingularTable := false - if scope.db != nil && scope.db.parent != nil { - scope.db.parent.RLock() - isSingularTable = scope.db.parent.singularTable - scope.db.parent.RUnlock() - } - - hashKey := struct { - singularTable bool - reflectType reflect.Type - }{isSingularTable, reflectType} - if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { - return value.(*ModelStruct) - } - - modelStruct.ModelType = reflectType - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if _, ok := field.TagSettingsGet("-"); ok { - field.IsIgnored = true - } else { - if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i < indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettingsGet(key); !ok { - field.TagSettingsSet(key, value) - } - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { - subField.DBName = prefix + subField.DBName - } - - if subField.IsPrimaryKey { - if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } else { - subField.IsPrimaryKey = false - } - } - - if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { - if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { - newJoinTableHandler := &JoinTableHandler{} - newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) - subField.Relationship.JoinTableHandler = newJoinTableHandler - } - } - - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - foreignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - relationship.Kind = "many_to_many" - - { // Foreign Keys for Source - joinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { - joinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - - // setup join table foreign keys for source - if len(joinTableDBNames) > idx { - // if defined join table's foreign key - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) - } else { - defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) - } - } - } - } - - { // Foreign Keys for Association (Destination) - associationJoinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { - associationJoinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for idx, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // setup join table foreign keys for association - if len(associationJoinTableDBNames) > idx { - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) - } else { - // join table foreign keys for association - joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - tagForeignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" - field.Relationship = relationship - } - } - }(field) - default: - field.IsNormal = true - } - } - } - - // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettingsGet("COLUMN"); ok { - field.DBName = value - } else { - field.DBName = ToColumnName(fieldStruct.Name) - } - - modelStruct.StructFields = append(modelStruct.StructFields, field) - } - } - - if len(modelStruct.PrimaryFields) == 0 { - if field := getForeignField("id", modelStruct.StructFields); field != nil { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } - - modelStructsMap.Store(hashKey, &modelStruct) - - return &modelStruct -} - -// GetStructFields get model's field structs -func (scope *Scope) GetStructFields() (fields []*StructField) { - return scope.GetModelStruct().StructFields -} - -func parseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} - for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - if str == "" { - continue - } - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k - } - } - } - return setting -} diff --git a/model_struct_test.go b/model_struct_test.go deleted file mode 100644 index 2ae419a0..00000000 --- a/model_struct_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package gorm_test - -import ( - "sync" - "testing" - - "github.com/jinzhu/gorm" -) - -type ModelA struct { - gorm.Model - Name string - - ModelCs []ModelC `gorm:"foreignkey:OtherAID"` -} - -type ModelB struct { - gorm.Model - Name string - - ModelCs []ModelC `gorm:"foreignkey:OtherBID"` -} - -type ModelC struct { - gorm.Model - Name string - - OtherAID uint64 - OtherA *ModelA `gorm:"foreignkey:OtherAID"` - OtherBID uint64 - OtherB *ModelB `gorm:"foreignkey:OtherBID"` -} - -// This test will try to cause a race condition on the model's foreignkey metadata -func TestModelStructRaceSameModel(t *testing.T) { - // use a WaitGroup to execute as much in-sync as possible - // it's more likely to hit a race condition than without - n := 32 - start := sync.WaitGroup{} - start.Add(n) - - // use another WaitGroup to know when the test is done - done := sync.WaitGroup{} - done.Add(n) - - for i := 0; i < n; i++ { - go func() { - start.Wait() - - // call GetStructFields, this had a race condition before we fixed it - DB.NewScope(&ModelA{}).GetStructFields() - - done.Done() - }() - - start.Done() - } - - done.Wait() -} - -// This test will try to cause a race condition on the model's foreignkey metadata -func TestModelStructRaceDifferentModel(t *testing.T) { - // use a WaitGroup to execute as much in-sync as possible - // it's more likely to hit a race condition than without - n := 32 - start := sync.WaitGroup{} - start.Add(n) - - // use another WaitGroup to know when the test is done - done := sync.WaitGroup{} - done.Add(n) - - for i := 0; i < n; i++ { - i := i - go func() { - start.Wait() - - // call GetStructFields, this had a race condition before we fixed it - if i%2 == 0 { - DB.NewScope(&ModelA{}).GetStructFields() - } else { - DB.NewScope(&ModelB{}).GetStructFields() - } - - done.Done() - }() - - start.Done() - } - - done.Wait() -} diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go deleted file mode 100644 index 32a14772..00000000 --- a/multi_primary_keys_test.go +++ /dev/null @@ -1,381 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "sort" - "testing" -) - -type Blog struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Subject string - Body string - Tags []Tag `gorm:"many2many:blog_tags;"` - SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` - LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` -} - -type Tag struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Value string - Blogs []*Blog `gorm:"many2many:blogs_tags"` -} - -func compareTags(tags []Tag, contents []string) bool { - var tagContents []string - for _, tag := range tags { - tagContents = append(tagContents, tag.Value) - } - sort.Strings(tagContents) - sort.Strings(contents) - return reflect.DeepEqual(tagContents, contents) -} - -func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - Tags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - - DB.Save(&blog) - if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) - if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("Tags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "Tags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("Tags").Find(&blog1) - if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog).Association("Tags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "Tags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("Tags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("Tags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "Tags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("Tags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog).Association("Tags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "Tags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog).Association("Tags").Clear() - if DB.Model(&blog).Association("Tags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("shared_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - SharedTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) - if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("SharedTags").Find(&blog1) - if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("SharedTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - DB.Model(&blog2).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("SharedTags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "SharedTags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog2).Association("SharedTags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "SharedTags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog2).Association("SharedTags").Clear() - if DB.Model(&blog).Association("SharedTags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("locale_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - LocaleTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) - if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog should has 0 tags after ZH Blog Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if len(tags) != 0 { - t.Errorf("Should find 0 tags with Related for EN Blog") - } - - var blog1 Blog - DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) - if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("LocaleTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related for EN Blog") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag4"}) { - t.Errorf("Should find 1 tags with Related for EN Blog") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) - - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - var blog11 Blog - DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) - if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - DB.Model(&blog2).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - var blog21 Blog - DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) - if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { - t.Errorf("EN Blog's tags should be changed after Replace") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Replace") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after Replace") - } - - // Delete - DB.Model(&blog).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") - } - - DB.Model(&blog2).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { - t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") - } - - // Clear - DB.Model(&blog2).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") - } - - DB.Model(&blog).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 0 { - t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared") - } - } -} diff --git a/naming.go b/naming.go deleted file mode 100644 index 6b0a4fdd..00000000 --- a/naming.go +++ /dev/null @@ -1,124 +0,0 @@ -package gorm - -import ( - "bytes" - "strings" -) - -// Namer is a function type which is given a string and return a string -type Namer func(string) string - -// NamingStrategy represents naming strategies -type NamingStrategy struct { - DB Namer - Table Namer - Column Namer -} - -// TheNamingStrategy is being initialized with defaultNamingStrategy -var TheNamingStrategy = &NamingStrategy{ - DB: defaultNamer, - Table: defaultNamer, - Column: defaultNamer, -} - -// AddNamingStrategy sets the naming strategy -func AddNamingStrategy(ns *NamingStrategy) { - if ns.DB == nil { - ns.DB = defaultNamer - } - if ns.Table == nil { - ns.Table = defaultNamer - } - if ns.Column == nil { - ns.Column = defaultNamer - } - TheNamingStrategy = ns -} - -// DBName alters the given name by DB -func (ns *NamingStrategy) DBName(name string) string { - return ns.DB(name) -} - -// TableName alters the given name by Table -func (ns *NamingStrategy) TableName(name string) string { - return ns.Table(name) -} - -// ColumnName alters the given name by Column -func (ns *NamingStrategy) ColumnName(name string) string { - return ns.Column(name) -} - -// ToDBName convert string to db name -func ToDBName(name string) string { - return TheNamingStrategy.DBName(name) -} - -// ToTableName convert string to table name -func ToTableName(name string) string { - return TheNamingStrategy.TableName(name) -} - -// ToColumnName convert string to db name -func ToColumnName(name string) string { - return TheNamingStrategy.ColumnName(name) -} - -var smap = newSafeMap() - -func defaultNamer(name string) string { - const ( - lower = false - upper = true - ) - - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber bool - ) - - for i, v := range value[:len(value)-1] { - nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} diff --git a/naming_test.go b/naming_test.go deleted file mode 100644 index 0c6f7713..00000000 --- a/naming_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestTheNamingStrategy(t *testing.T) { - - cases := []struct { - name string - namer gorm.Namer - expected string - }{ - {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, - {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, - {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - result := c.namer(c.name) - if result != c.expected { - t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) - } - }) - } - -} - -func TestNamingStrategy(t *testing.T) { - - dbNameNS := func(name string) string { - return "db_" + name - } - tableNameNS := func(name string) string { - return "tbl_" + name - } - columnNameNS := func(name string) string { - return "col_" + name - } - - ns := &gorm.NamingStrategy{ - DB: dbNameNS, - Table: tableNameNS, - Column: columnNameNS, - } - - cases := []struct { - name string - namer gorm.Namer - expected string - }{ - {name: "auth", expected: "db_auth", namer: ns.DB}, - {name: "user", expected: "tbl_user", namer: ns.Table}, - {name: "password", expected: "col_password", namer: ns.Column}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - result := c.namer(c.name) - if result != c.expected { - t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) - } - }) - } - -} diff --git a/pointer_test.go b/pointer_test.go deleted file mode 100644 index 2a68a5ab..00000000 --- a/pointer_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package gorm_test - -import "testing" - -type PointerStruct struct { - ID int64 - Name *string - Num *int -} - -type NormalStruct struct { - ID int64 - Name string - Num int -} - -func TestPointerFields(t *testing.T) { - DB.DropTable(&PointerStruct{}) - DB.AutoMigrate(&PointerStruct{}) - var name = "pointer struct 1" - var num = 100 - pointerStruct := PointerStruct{Name: &name, Num: &num} - if DB.Create(&pointerStruct).Error != nil { - t.Errorf("Failed to save pointer struct") - } - - var pointerStructResult PointerStruct - if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { - t.Errorf("Failed to query saved pointer struct") - } - - var tableName = DB.NewScope(&PointerStruct{}).TableName() - - var normalStruct NormalStruct - DB.Table(tableName).First(&normalStruct) - if normalStruct.Name != name || normalStruct.Num != num { - t.Errorf("Failed to query saved Normal struct") - } - - var nilPointerStruct = PointerStruct{} - if err := DB.Create(&nilPointerStruct).Error; err != nil { - t.Error("Failed to save nil pointer struct", err) - } - - var pointerStruct2 PointerStruct - if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var normalStruct2 NormalStruct - if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var partialNilPointerStruct1 = PointerStruct{Num: &num} - if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct3 PointerStruct - if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct3 NormalStruct - if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { - t.Error("Failed to query saved partial pointer struct", err) - } - - var partialNilPointerStruct2 = PointerStruct{Name: &name} - if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct4 PointerStruct - if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct4 NormalStruct - if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { - t.Error("Failed to query saved partial pointer struct", err) - } -} diff --git a/polymorphic_test.go b/polymorphic_test.go deleted file mode 100644 index d1ecfbbb..00000000 --- a/polymorphic_test.go +++ /dev/null @@ -1,366 +0,0 @@ -package gorm_test - -import ( - "reflect" - "sort" - "testing" -) - -type Cat struct { - Id int - Name string - Toy Toy `gorm:"polymorphic:Owner;"` -} - -type Dog struct { - Id int - Name string - Toys []Toy `gorm:"polymorphic:Owner;"` -} - -type Hamster struct { - Id int - Name string - PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` - OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` -} - -type Toy struct { - Id int - Name string - OwnerId int - OwnerType string -} - -var compareToys = func(toys []Toy, contents []string) bool { - var toyContents []string - for _, toy := range toys { - toyContents = append(toyContents, toy.Name) - } - sort.Strings(toyContents) - sort.Strings(contents) - return reflect.DeepEqual(toyContents, contents) -} - -func TestPolymorphic(t *testing.T) { - cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}} - dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} - DB.Save(&cat).Save(&dog) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Dog's toys count should be 2") - } - - // Query - var catToys []Toy - if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(catToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if catToys[0].Name != cat.Toy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - var dogToys []Toy - if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { - t.Errorf("Did not find any polymorphic has many associations") - } else if len(dogToys) != len(dog.Toys) { - t.Errorf("Should have found all polymorphic has many associations") - } - - var catToy Toy - DB.Model(&cat).Association("Toy").Find(&catToy) - if catToy.Name != cat.Toy.Name { - t.Errorf("Should find has one polymorphic association") - } - - var dogToys1 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys1) - if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { - t.Errorf("Should find has many polymorphic association") - } - - // Append - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - var catToy2 Toy - DB.Model(&cat).Association("Toy").Find(&catToy2) - if catToy2.Name != "cat toy 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Should return two polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 3", - }) - - var dogToys2 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys2) - if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { - t.Errorf("Dog's toys should be updated with Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Replace - DB.Model(&cat).Association("Toy").Replace(&Toy{ - Name: "cat toy 3", - }) - - var catToy3 Toy - DB.Model(&cat).Association("Toy").Find(&catToy3) - if catToy3.Name != "cat toy 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Replace(&Toy{ - Name: "dog toy 4", - }, []Toy{ - {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, - }) - - var dogToys3 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys3) - if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { - t.Errorf("Dog's toys should be updated with Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Delete - DB.Model(&cat).Association("Toy").Delete(&catToy2) - - var catToy4 Toy - DB.Model(&cat).Association("Toy").Find(&catToy4) - if catToy4.Name != "cat toy 3" { - t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should be 4") - } - - DB.Model(&cat).Association("Toy").Delete(&catToy3) - - if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { - t.Errorf("Toy should be deleted with Delete") - } - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys count should be 0 after Delete") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete cat's toy") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys2) - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete unrelated toys") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys3) - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys count should be deleted with Delete") - } - - // Clear - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys should be added with Append") - } - - DB.Model(&cat).Association("Toy").Clear() - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys should be cleared with Clear") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 8", - }) - - if DB.Model(&dog).Association("Toys").Count() != 1 { - t.Errorf("Dog's toys should be added with Append") - } - - DB.Model(&dog).Association("Toys").Clear() - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys should be cleared with Clear") - } -} - -func TestNamedPolymorphic(t *testing.T) { - hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} - DB.Save(&hamster) - - hamster2 := Hamster{} - DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) - if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { - t.Errorf("Hamster's preferred toy couldn't be preloaded") - } - if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name { - t.Errorf("Hamster's other toy couldn't be preloaded") - } - - // clear to omit Toy.Id in count - hamster2.PreferredToy = Toy{} - hamster2.OtherToy = Toy{} - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's preferred toy count should be 1") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's other toy count should be 1") - } - - // Query - var hamsterToys []Toy - if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(hamsterToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if hamsterToys[0].Name != hamster.PreferredToy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(hamsterToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if hamsterToys[0].Name != hamster.OtherToy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - hamsterToy := Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != hamster.PreferredToy.Name { - t.Errorf("Should find has one polymorphic association") - } - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != hamster.OtherToy.Name { - t.Errorf("Should find has one polymorphic association") - } - - // Append - DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ - Name: "bike 2", - }) - DB.Model(&hamster).Association("OtherToy").Append(&Toy{ - Name: "treadmill 2", - }) - - hamsterToy = Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != "bike 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != "treadmill 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's toys count should be 1 after Append") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's toys count should be 1 after Append") - } - - // Replace - DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ - Name: "bike 3", - }) - DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ - Name: "treadmill 3", - }) - - hamsterToy = Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != "bike 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != "treadmill 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("hamster's toys count should be 1 after Replace") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("hamster's toys count should be 1 after Replace") - } - - // Clear - DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ - Name: "bike 2", - }) - DB.Model(&hamster).Association("OtherToy").Append(&Toy{ - Name: "treadmill 2", - }) - - if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's toys should be added with Append") - } - if DB.Model(&hamster).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's toys should be added with Append") - } - - DB.Model(&hamster).Association("PreferredToy").Clear() - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { - t.Errorf("Hamster's preferred toy should be cleared with Clear") - } - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's other toy should be still available") - } - - DB.Model(&hamster).Association("OtherToy").Clear() - if DB.Model(&hamster).Association("OtherToy").Count() != 0 { - t.Errorf("Hamster's other toy should be cleared with Clear") - } -} diff --git a/preload_test.go b/preload_test.go deleted file mode 100644 index dd29fb5e..00000000 --- a/preload_test.go +++ /dev/null @@ -1,1701 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "encoding/json" - "os" - "reflect" - "testing" - - "github.com/jinzhu/gorm" -) - -func getPreloadUser(name string) *User { - return getPreparedUser(name, "Preload") -} - -func checkUserHasPreloadData(user User, t *testing.T) { - u := getPreloadUser(user.Name) - if user.BillingAddress.Address1 != u.BillingAddress.Address1 { - t.Error("Failed to preload user's BillingAddress") - } - - if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { - t.Error("Failed to preload user's ShippingAddress") - } - - if user.CreditCard.Number != u.CreditCard.Number { - t.Error("Failed to preload user's CreditCard") - } - - if user.Company.Name != u.Company.Name { - t.Error("Failed to preload user's Company") - } - - if len(user.Emails) != len(u.Emails) { - t.Error("Failed to preload user's Emails") - } else { - var found int - for _, e1 := range u.Emails { - for _, e2 := range user.Emails { - if e1.Email == e2.Email { - found++ - break - } - } - } - if found != len(u.Emails) { - t.Error("Failed to preload user's email details") - } - } -} - -func TestPreload(t *testing.T) { - user1 := getPreloadUser("user1") - DB.Save(user1) - - preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("user2") - DB.Save(user2) - - user3 := getPreloadUser("user3") - DB.Save(user3) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } - - var users3 []*User - preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) - - for _, user := range users3 { - if user.Name == user3.Name { - if len(user.Emails) != 1 { - t.Errorf("should only preload one emails for user3 when with condition") - } - } else if len(user.Emails) != 0 { - t.Errorf("should not preload any emails for other users when with condition") - } else if user.Emails == nil { - t.Errorf("should return an empty slice to indicate zero results") - } - } -} - -func TestAutoPreload(t *testing.T) { - user1 := getPreloadUser("auto_user1") - DB.Save(user1) - - preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("auto_user2") - DB.Save(user2) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } -} - -func TestAutoPreloadFalseDoesntPreload(t *testing.T) { - user1 := getPreloadUser("auto_user1") - DB.Save(user1) - - preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") - var user User - preloadDB.Find(&user) - - if user.BillingAddress.Address1 != "" { - t.Error("AutoPreload was set to fasle, but still fetched data") - } - - user2 := getPreloadUser("auto_user2") - DB.Save(user2) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - if user.BillingAddress.Address1 != "" { - t.Error("AutoPreload was set to fasle, but still fetched data") - } - } -} - -func TestNestedPreload1(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []*Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2s: []Level2{ - { - Level1s: []*Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []*Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - Name string - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload4(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -// Slice: []Level3 -func TestNestedPreload5(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload6(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - - want[1] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value5"}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload7(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - - want[1] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value3"}}, - {Level1: Level1{Value: "value4"}}, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload8(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload9(t *testing.T) { - type ( - Level0 struct { - ID uint - Value string - Level1ID uint - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2_1ID uint - Level0s []Level0 - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level2_1 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - Level2_1 Level2_1 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level2_1{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level0{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - { - Value: "value1-1", - Level0s: []Level0{{Value: "Level0-1"}}, - }, - { - Value: "value2-2", - Level0s: []Level0{{Value: "Level0-2"}}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - { - Value: "value3-3", - Level0s: []Level0{}, - }, - { - Value: "value4-4", - Level0s: []Level0{}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelA1 struct { - ID uint - Value string -} - -type LevelA2 struct { - ID uint - Value string - LevelA3s []*LevelA3 -} - -type LevelA3 struct { - ID uint - Value string - LevelA1ID sql.NullInt64 - LevelA1 *LevelA1 - LevelA2ID sql.NullInt64 - LevelA2 *LevelA2 -} - -func TestNestedPreload10(t *testing.T) { - DB.DropTableIfExists(&LevelA3{}) - DB.DropTableIfExists(&LevelA2{}) - DB.DropTableIfExists(&LevelA1{}) - - if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { - t.Error(err) - } - - levelA1 := &LevelA1{Value: "foo"} - if err := DB.Save(levelA1).Error; err != nil { - t.Error(err) - } - - want := []*LevelA2{ - { - Value: "bar", - LevelA3s: []*LevelA3{ - { - Value: "qux", - LevelA1: levelA1, - }, - }, - }, - { - Value: "bar 2", - LevelA3s: []*LevelA3{}, - }, - } - for _, levelA2 := range want { - if err := DB.Save(levelA2).Error; err != nil { - t.Error(err) - } - } - - var got []*LevelA2 - if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelB1 struct { - ID uint - Value string - LevelB3s []*LevelB3 -} - -type LevelB2 struct { - ID uint - Value string -} - -type LevelB3 struct { - ID uint - Value string - LevelB1ID sql.NullInt64 - LevelB1 *LevelB1 - LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` -} - -func TestNestedPreload11(t *testing.T) { - DB.DropTableIfExists(&LevelB2{}) - DB.DropTableIfExists(&LevelB3{}) - DB.DropTableIfExists(&LevelB1{}) - if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { - t.Error(err) - } - - levelB1 := &LevelB1{Value: "foo"} - if err := DB.Create(levelB1).Error; err != nil { - t.Error(err) - } - - levelB3 := &LevelB3{ - Value: "bar", - LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, - LevelB2s: []*LevelB2{}, - } - if err := DB.Create(levelB3).Error; err != nil { - t.Error(err) - } - levelB1.LevelB3s = []*LevelB3{levelB3} - - want := []*LevelB1{levelB1} - var got []*LevelB1 - if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelC1 struct { - ID uint - Value string - LevelC2ID uint -} - -type LevelC2 struct { - ID uint - Value string - LevelC1 LevelC1 -} - -type LevelC3 struct { - ID uint - Value string - LevelC2ID uint - LevelC2 LevelC2 -} - -func TestNestedPreload12(t *testing.T) { - DB.DropTableIfExists(&LevelC2{}) - DB.DropTableIfExists(&LevelC3{}) - DB.DropTableIfExists(&LevelC1{}) - if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil { - t.Error(err) - } - - level2 := LevelC2{ - Value: "c2", - LevelC1: LevelC1{ - Value: "c1", - }, - } - DB.Create(&level2) - - want := []LevelC3{ - { - Value: "c3-1", - LevelC2: level2, - }, { - Value: "c3-2", - LevelC2: level2, - }, - } - - for i := range want { - if err := DB.Create(&want[i]).Error; err != nil { - t.Error(err) - } - } - - var got []LevelC3 - if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" { - return - } - - type ( - Level1 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - } - Level2 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - Level1s []Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ - {Value: "ru", LanguageCode: "ru"}, - {Value: "en", LanguageCode: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ - {Value: "zh", LanguageCode: "zh"}, - {Value: "de", LanguageCode: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []Level1{ruLevel1} - got2.Level1s = []Level1{zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } - - if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { - t.Error(err) - } -} - -func TestManyToManyPreloadForNestedPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Bob", - Level2: &Level2{ - Value: "Foo", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level3{ - Value: "Tom", - Level2: &Level2{ - Value: "Bar", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level3 - if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) - } - - var got4 []Level3 - if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level3 - DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level2.Level1s = []*Level1{&ruLevel1} - got2.Level2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) - } -} - -func TestNestedManyToManyPreload(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2s []Level2 `gorm:"many2many:level2_level3;"` - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Level3", - Level2s: []Level2{ - { - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, { - Value: "Tom", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - }, - } - - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedManyToManyPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Level3", - Level2: &Level2{ - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } - - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedManyToManyPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - level1Zh := &Level1{Value: "zh"} - level1Ru := &Level1{Value: "ru"} - level1En := &Level1{Value: "en"} - - level21 := &Level2{ - Value: "Level2-1", - Level1s: []*Level1{level1Zh, level1Ru}, - } - - level22 := &Level2{ - Value: "Level2-2", - Level1s: []*Level1{level1Zh, level1En}, - } - - wants := []*Level3{ - { - Value: "Level3-1", - Level2: level21, - }, - { - Value: "Level3-2", - Level2: level22, - }, - { - Value: "Level3-3", - Level2: level21, - }, - } - - for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - } - - var gots []*Level3 - if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { - return db.Order("level1.id ASC") - }).Find(&gots).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(gots, wants) { - t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) - } -} - -func TestNestedManyToManyPreload3ForStruct(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - level1Zh := Level1{Value: "zh"} - level1Ru := Level1{Value: "ru"} - level1En := Level1{Value: "en"} - - level21 := Level2{ - Value: "Level2-1", - Level1s: []Level1{level1Zh, level1Ru}, - } - - level22 := Level2{ - Value: "Level2-2", - Level1s: []Level1{level1Zh, level1En}, - } - - wants := []*Level3{ - { - Value: "Level3-1", - Level2: level21, - }, - { - Value: "Level3-2", - Level2: level22, - }, - { - Value: "Level3-3", - Level2: level21, - }, - } - - for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - } - - var gots []*Level3 - if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { - return db.Order("level1.id ASC") - }).Find(&gots).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(gots, wants) { - t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) - } -} - -func TestNestedManyToManyPreload4(t *testing.T) { - type ( - Level4 struct { - ID uint - Value string - Level3ID uint - } - Level3 struct { - ID uint - Value string - Level4s []*Level4 - } - Level2 struct { - ID uint - Value string - Level3s []*Level3 `gorm:"many2many:level2_level3;"` - } - Level1 struct { - ID uint - Value string - Level2s []*Level2 `gorm:"many2many:level1_level2;"` - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level4{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") - - dummy := Level1{ - Value: "Level1", - Level2s: []*Level2{{ - Value: "Level2", - Level3s: []*Level3{{ - Value: "Level3", - Level4s: []*Level4{{ - Value: "Level4", - }}, - }}, - }}, - } - - if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - if err := DB.Save(&dummy).Error; err != nil { - t.Error(err) - } - - var level1 Level1 - if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { - t.Error(err) - } -} - -func TestManyToManyPreloadForPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level2 - DB.Preload("Level1s").First(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []*Level1{&ruLevel1} - got2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } -} - -func TestNilPointerSlice(t *testing.T) { - type ( - Level3 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level3ID uint - Level3 *Level3 - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level1{ - Value: "Bob", - Level2: &Level2{ - Value: "en", - Level3: &Level3{ - Value: "native", - }, - }, - } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level1{ - Value: "Tom", - Level2: nil, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got []Level1 - if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { - t.Error(err) - } - - if len(got) != 2 { - t.Errorf("got %v items, expected 2", len(got)) - } - - if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) - } - - if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) - } -} - -func TestNilPointerSlice2(t *testing.T) { - type ( - Level4 struct { - ID uint - } - Level3 struct { - ID uint - Level4ID sql.NullInt64 `sql:"index"` - Level4 *Level4 - } - Level2 struct { - ID uint - Level3s []*Level3 `gorm:"many2many:level2_level3s"` - } - Level1 struct { - ID uint - Level2ID sql.NullInt64 `sql:"index"` - Level2 *Level2 - } - ) - - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) - - if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { - t.Error(err) - } - - want := new(Level1) - if err := DB.Save(want).Error; err != nil { - t.Error(err) - } - - got := new(Level1) - err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error - if err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestPrefixedPreloadDuplication(t *testing.T) { - type ( - Level4 struct { - ID uint - Name string - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level4s []*Level4 - } - Level2 struct { - ID uint - Name string - Level3ID sql.NullInt64 `sql:"index"` - Level3 *Level3 - } - Level1 struct { - ID uint - Name string - Level2ID sql.NullInt64 `sql:"index"` - Level2 *Level2 - } - ) - - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) - - if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { - t.Error(err) - } - - lvl := &Level3{} - if err := DB.Save(lvl).Error; err != nil { - t.Error(err) - } - - sublvl1 := &Level4{Level3ID: lvl.ID} - if err := DB.Save(sublvl1).Error; err != nil { - t.Error(err) - } - sublvl2 := &Level4{Level3ID: lvl.ID} - if err := DB.Save(sublvl2).Error; err != nil { - t.Error(err) - } - - lvl.Level4s = []*Level4{sublvl1, sublvl2} - - want1 := Level1{ - Level2: &Level2{ - Level3: lvl, - }, - } - if err := DB.Save(&want1).Error; err != nil { - t.Error(err) - } - - want2 := Level1{ - Level2: &Level2{ - Level3: lvl, - }, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - want := []Level1{want1, want2} - - var got []Level1 - err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error - if err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestPreloadManyToManyCallbacks(t *testing.T) { - type ( - Level2 struct { - ID uint - Name string - } - Level1 struct { - ID uint - Name string - Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` - } - ) - - DB.DropTableIfExists("level1_level2s") - DB.DropTableIfExists(new(Level1)) - DB.DropTableIfExists(new(Level2)) - - if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { - t.Error(err) - } - - lvl := Level1{ - Name: "l1", - Level2s: []Level2{ - {Name: "l2-1"}, {Name: "l2-2"}, - }, - } - DB.Save(&lvl) - - called := 0 - - DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { - called = called + 1 - }) - - DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) - - if called != 3 { - t.Errorf("Wanted callback to be called 3 times but got %d", called) - } -} - -func toJSONString(v interface{}) []byte { - r, _ := json.MarshalIndent(v, "", " ") - return r -} diff --git a/query_test.go b/query_test.go deleted file mode 100644 index a23a9e24..00000000 --- a/query_test.go +++ /dev/null @@ -1,841 +0,0 @@ -package gorm_test - -import ( - "fmt" - "reflect" - - "github.com/jinzhu/gorm" - - "testing" - "time" -) - -func TestFirstAndLast(t *testing.T) { - DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) - DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) - - var user1, user2, user3, user4 User - DB.First(&user1) - DB.Order("id").Limit(1).Find(&user2) - - ptrOfUser3 := &user3 - DB.Last(&ptrOfUser3) - DB.Order("id desc").Limit(1).Find(&user4) - if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should by order by primary key") - } - - var users []User - DB.First(&users) - if len(users) != 1 { - t.Errorf("Find first record as slice") - } - - var user User - if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { - t.Errorf("Should not raise any error when order with Join table") - } - - if user.Email != "" { - t.Errorf("User's Email should be blank as no one set it") - } -} - -func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { - DB.Save(&Animal{Name: "animal1"}) - DB.Save(&Animal{Name: "animal2"}) - - var animal1, animal2, animal3, animal4 Animal - DB.First(&animal1) - DB.Order("counter").Limit(1).Find(&animal2) - - DB.Last(&animal3) - DB.Order("counter desc").Limit(1).Find(&animal4) - if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { - t.Errorf("First and Last should work correctly") - } -} - -func TestFirstAndLastWithRaw(t *testing.T) { - user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}} - user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}} - DB.Save(&user1) - DB.Save(&user2) - - var user3, user4 User - DB.Raw("select * from users WHERE name = ?", "user").First(&user3) - if user3.Id != user1.Id { - t.Errorf("Find first record with raw") - } - - DB.Raw("select * from users WHERE name = ?", "user").Last(&user4) - if user4.Id != user2.Id { - t.Errorf("Find last record with raw") - } -} - -func TestUIntPrimaryKey(t *testing.T) { - var animal Animal - DB.First(&animal, uint64(1)) - if animal.Counter != 1 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } - - DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) - if animal.Counter != 2 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } -} - -func TestCustomizedTypePrimaryKey(t *testing.T) { - type ID uint - type CustomizedTypePrimaryKey struct { - ID ID - Name string - } - - DB.AutoMigrate(&CustomizedTypePrimaryKey{}) - - p1 := CustomizedTypePrimaryKey{Name: "p1"} - p2 := CustomizedTypePrimaryKey{Name: "p2"} - p3 := CustomizedTypePrimaryKey{Name: "p3"} - DB.Create(&p1) - DB.Create(&p2) - DB.Create(&p3) - - var p CustomizedTypePrimaryKey - - if err := DB.First(&p, p2.ID).Error; err == nil { - t.Errorf("Should return error for invalid query condition") - } - - if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { - t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) - } - - if p.Name != "p2" { - t.Errorf("Should find correct value when querying with customized type for primary key") - } -} - -func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { - type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string - } - - DB.AutoMigrate(&AddressByZipCode{}) - DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) - - var address AddressByZipCode - DB.First(&address, "00501") - if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) - } -} - -func TestFindAsSliceOfPointers(t *testing.T) { - DB.Save(&User{Name: "user"}) - - var users []User - DB.Find(&users) - - var userPointers []*User - DB.Find(&userPointers) - - if len(users) == 0 || len(users) != len(userPointers) { - t.Errorf("Find slice of pointers") - } -} - -func TestSearchWithPlainSQL(t *testing.T) { - user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") - - if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("Search with plain SQL") - } - - if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { - t.Errorf("Search with plan SQL (regexp)") - } - - var users []User - DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1) - if len(users) != 2 { - t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) - } - - DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) - } - - scopedb.Where("age <> ?", 20).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users age != 20, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", "2002-10-10").Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) - } - - scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) - } - - DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users, but got %v", len(users)) - } - - DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users, but got %v", len(users)) - } - - DB.Where("id in (?)", user1.Id).Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users, but got %v", len(users)) - } - - if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { - t.Errorf("Should not get RecordNotFound error when looking for none existing records") - } -} - -func TestSearchWithTwoDimensionalArray(t *testing.T) { - var users []User - user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Create(&user1) - DB.Create(&user2) - DB.Create(&user3) - - if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" { - if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { - t.Errorf("No error should happen when query with 2D array, but got %v", err) - - if len(users) != 2 { - t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) - } - } - } - - if dialect := DB.Dialect().GetName(); dialect == "mssql" { - if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { - t.Errorf("No error should happen when query with 2D array, but got %v", err) - - if len(users) != 2 { - t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) - } - } - } -} - -func TestSearchWithStruct(t *testing.T) { - user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where(user1.Id).First(&User{}).RecordNotFound() { - t.Errorf("Search with primary key") - } - - if DB.First(&User{}, user1.Id).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - var users []User - DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) - } - - var user User - DB.First(&user, &User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline pointer of struct") - } - - DB.First(&user, User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline struct") - } - - DB.Where(&User{Name: user1.Name}).First(&user) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with where struct") - } - - DB.Find(&users, &User{Name: user2.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline struct") - } -} - -func TestSearchWithMap(t *testing.T) { - companyID := 1 - user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) - - var user User - DB.First(&user, map[string]interface{}{"name": user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline map") - } - - user = User{} - DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) - if user.Id == 0 || user.Name != user2.Name { - t.Errorf("Search first record with where map") - } - - var users []User - DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user3.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) - if len(users) != 0 { - t.Errorf("Search all records with inline map containing null value finding 0 records") - } - - DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) - if len(users) != 1 { - t.Errorf("Search all records with inline map containing null value finding 1 record") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) - if len(users) != 1 { - t.Errorf("Search all records with inline multiple value map") - } -} - -func TestSearchWithEmptyChain(t *testing.T) { - user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where("").Where("").First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty strings") - } - - if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty struct") - } - - if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty map") - } -} - -func TestSelect(t *testing.T) { - user1 := User{Name: "SelectUser1"} - DB.Save(&user1) - - var user User - DB.Where("name = ?", user1.Name).Select("name").Find(&user) - if user.Id != 0 { - t.Errorf("Should not have ID because only selected name, %+v", user.Id) - } - - if user.Name != user1.Name { - t.Errorf("Should have user Name when selected it") - } -} - -func TestOrderAndPluck(t *testing.T) { - user1 := User{Name: "OrderPluckUser1", Age: 1} - user2 := User{Name: "OrderPluckUser2", Age: 10} - user3 := User{Name: "OrderPluckUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") - - var user User - scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user) - if user.Name != "OrderPluckUser2" { - t.Errorf("Order with sql expression") - } - - var ages []int64 - scopedb.Order("age desc").Pluck("age", &ages) - if ages[0] != 20 { - t.Errorf("The first age should be 20 when order with age desc") - } - - var ages1, ages2 []int64 - scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) - if !reflect.DeepEqual(ages1, ages2) { - t.Errorf("The first order is the primary order") - } - - var ages3, ages4 []int64 - scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) - if reflect.DeepEqual(ages3, ages4) { - t.Errorf("Reorder should work") - } - - var names []string - var ages5 []int64 - scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) - if names != nil && ages5 != nil { - if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { - t.Errorf("Order with multiple orders") - } - } else { - t.Errorf("Order with multiple orders") - } - - var ages6 []int64 - if err := scopedb.Order("").Pluck("age", &ages6).Error; err != nil { - t.Errorf("An empty string as order clause produces invalid queries") - } - - DB.Model(User{}).Select("name, age").Find(&[]User{}) -} - -func TestLimit(t *testing.T) { - user1 := User{Name: "LimitUser1", Age: 1} - user2 := User{Name: "LimitUser2", Age: 10} - user3 := User{Name: "LimitUser3", Age: 20} - user4 := User{Name: "LimitUser4", Age: 10} - user5 := User{Name: "LimitUser5", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) - - var users1, users2, users3 []User - DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) - - if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") - } -} - -func TestOffset(t *testing.T) { - for i := 0; i < 20; i++ { - DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) - } - var users1, users2, users3, users4 []User - DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) - - if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { - t.Errorf("Offset should work") - } -} - -func TestLimitAndOffsetSQL(t *testing.T) { - user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10} - user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20} - user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30} - user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40} - user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50} - if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil { - t.Fatal(err) - } - - tests := []struct { - name string - limit, offset interface{} - users []*User - ok bool - }{ - { - name: "OK", - limit: float64(2), - offset: float64(2), - users: []*User{ - &User{Name: "TestLimitAndOffsetSQL3", Age: 30}, - &User{Name: "TestLimitAndOffsetSQL2", Age: 20}, - }, - ok: true, - }, - { - name: "Limit parse error", - limit: float64(1000000), // 1e+06 - offset: float64(2), - ok: false, - }, - { - name: "Offset parse error", - limit: float64(2), - offset: float64(1000000), // 1e+06 - ok: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var users []*User - err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error - if tt.ok { - if err != nil { - t.Errorf("error expected nil, but got %v", err) - } - if len(users) != len(tt.users) { - t.Errorf("users length expected %d, but got %d", len(tt.users), len(users)) - } - for i := range tt.users { - if users[i].Name != tt.users[i].Name { - t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name) - } - if users[i].Age != tt.users[i].Age { - t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age) - } - } - } else { - if err == nil { - t.Error("error expected not nil, but got nil") - } - } - }) - } -} - -func TestOr(t *testing.T) { - user1 := User{Name: "OrUser1", Age: 1} - user2 := User{Name: "OrUser2", Age: 10} - user3 := User{Name: "OrUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users []User - DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) - if len(users) != 2 { - t.Errorf("Find users with or") - } -} - -func TestCount(t *testing.T) { - user1 := User{Name: "CountUser1", Age: 1} - user2 := User{Name: "CountUser2", Age: 10} - user3 := User{Name: "CountUser3", Age: 20} - - DB.Save(&user1).Save(&user2).Save(&user3) - var count, count1, count2 int64 - var users []User - - if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { - t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) - } - - if count != int64(len(users)) { - t.Errorf("Count() method should get correct value") - } - - DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) - if count1 != 1 || count2 != 3 { - t.Errorf("Multiple count in chain") - } - - var count3 int - if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { - t.Errorf("Not error should happen, but got %v", err) - } - - if count3 != 2 { - t.Errorf("Should get correct count, but got %v", count3) - } -} - -func TestNot(t *testing.T) { - DB.Create(getPreparedUser("user1", "not")) - DB.Create(getPreparedUser("user2", "not")) - DB.Create(getPreparedUser("user3", "not")) - - user4 := getPreparedUser("user4", "not") - user4.Company = Company{} - DB.Create(user4) - - DB := DB.Where("role = ?", "not") - - var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User - if DB.Find(&users1).RowsAffected != 4 { - t.Errorf("should find 4 not users") - } - DB.Not(users1[0].Id).Find(&users2) - - if len(users1)-len(users2) != 1 { - t.Errorf("Should ignore the first users with Not") - } - - DB.Not([]int{}).Find(&users3) - if len(users1)-len(users3) != 0 { - t.Errorf("Should find all users with a blank condition") - } - - var name3Count int64 - DB.Table("users").Where("name = ?", "user3").Count(&name3Count) - DB.Not("name", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not("name = ?", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not("name <> ?", "user3").Find(&users4) - if len(users4) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not(User{Name: "user3"}).Find(&users5) - - if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) - if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) - if len(users1)-len(users7) != 2 { // not user3 or user4 - t.Errorf("Should find all user's name not equal to 3 who do not have company id") - } - - DB.Not("name", []string{"user3"}).Find(&users8) - if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - var name2Count int64 - DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users9) - if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users' name not equal 3") - } -} - -func TestFillSmallerStruct(t *testing.T) { - user1 := User{Name: "SmallerUser", Age: 100} - DB.Save(&user1) - type SimpleUser struct { - Name string - Id int64 - UpdatedAt time.Time - CreatedAt time.Time - } - - var simpleUser SimpleUser - DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) - - if simpleUser.Id == 0 || simpleUser.Name == "" { - t.Errorf("Should fill data correctly into smaller struct") - } -} - -func TestFindOrInitialize(t *testing.T) { - var user1, user2, user3, user4, user5, user6 User - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) - if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) - if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) - if user3.Name != "find or init 2" || user3.Id != 0 { - t.Errorf("user should be initialized with inline search value") - } - - DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and attrs") - } - - DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and assign attrs") - } - - DB.Save(&User{Name: "find or init", Age: 33}) - DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { - t.Errorf("user should be found with FirstOrInit") - } - - DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } -} - -func TestFindOrCreate(t *testing.T) { - var user1, user2, user3, user4, user5, user6, user7, user8 User - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) - if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) - if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) - if user3.Name != "find or create 2" || user3.Id == 0 { - t.Errorf("user should be created with inline search value") - } - - DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) - if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and attrs") - } - - updatedAt1 := user4.UpdatedAt - DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) - if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateAt should be changed when update values with assign") - } - - DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) - if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) - if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Find(&user7) - if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) - if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { - t.Errorf("embedded struct email should be saved") - } - - if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { - t.Errorf("embedded struct credit card should be saved") - } -} - -func TestSelectWithEscapedFieldName(t *testing.T) { - user1 := User{Name: "EscapedFieldNameUser", Age: 1} - user2 := User{Name: "EscapedFieldNameUser", Age: 10} - user3 := User{Name: "EscapedFieldNameUser", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var names []string - DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) - - if len(names) != 3 { - t.Errorf("Expected 3 name, but got: %d", len(names)) - } -} - -func TestSelectWithVariables(t *testing.T) { - DB.Save(&User{Name: "jinzhu"}) - - rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() - - if !rows.Next() { - t.Errorf("Should have returned at least one row") - } else { - columns, _ := rows.Columns() - if !reflect.DeepEqual(columns, []string{"fake"}) { - t.Errorf("Should only contains one column") - } - } - - rows.Close() -} - -func TestSelectWithArrayInput(t *testing.T) { - DB.Save(&User{Name: "jinzhu", Age: 42}) - - var user User - DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) - - if user.Name != "jinzhu" || user.Age != 42 { - t.Errorf("Should have selected both age and name") - } -} - -func TestPluckWithSelect(t *testing.T) { - var ( - user = User{Name: "matematik7_pluck_with_select", Age: 25} - combinedName = fmt.Sprintf("%v%v", user.Name, user.Age) - combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) - ) - - if dialect := DB.Dialect().GetName(); dialect == "sqlite3" { - combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) - } - - DB.Save(&user) - - selectStr := combineUserAgeSQL + " as user_age" - var userAges []string - err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error - if err != nil { - t.Error(err) - } - - if len(userAges) != 1 || userAges[0] != combinedName { - t.Errorf("Should correctly pluck with select, got: %s", userAges) - } - - selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age")) - userAges = userAges[:0] - err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error - if err != nil { - t.Error(err) - } - - if len(userAges) != 1 || userAges[0] != combinedName { - t.Errorf("Should correctly pluck with select, got: %s", userAges) - } -} diff --git a/scaner_test.go b/scaner_test.go deleted file mode 100644 index 9e251dd6..00000000 --- a/scaner_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package gorm_test - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestScannableSlices(t *testing.T) { - if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { - t.Errorf("Should create table with slice values correctly: %s", err) - } - - r1 := RecordWithSlice{ - Strings: ExampleStringSlice{"a", "b", "c"}, - Structs: ExampleStructSlice{ - {"name1", "value1"}, - {"name2", "value2"}, - }, - } - - if err := DB.Save(&r1).Error; err != nil { - t.Errorf("Should save record with slice values") - } - - var r2 RecordWithSlice - - if err := DB.Find(&r2).Error; err != nil { - t.Errorf("Should fetch record with slice values") - } - - if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { - t.Errorf("Should have serialised and deserialised a string array") - } - - if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { - t.Errorf("Should have serialised and deserialised a struct array") - } -} - -type RecordWithSlice struct { - ID uint64 - Strings ExampleStringSlice `sql:"type:text"` - Structs ExampleStructSlice `sql:"type:text"` -} - -type ExampleStringSlice []string - -func (l ExampleStringSlice) Value() (driver.Value, error) { - bytes, err := json.Marshal(l) - return string(bytes), err -} - -func (l *ExampleStringSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} - -type ExampleStruct struct { - Name string - Value string -} - -type ExampleStructSlice []ExampleStruct - -func (l ExampleStructSlice) Value() (driver.Value, error) { - bytes, err := json.Marshal(l) - return string(bytes), err -} - -func (l *ExampleStructSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} - -type ScannerDataType struct { - Street string `sql:"TYPE:varchar(24)"` -} - -func (ScannerDataType) Value() (driver.Value, error) { - return nil, nil -} - -func (*ScannerDataType) Scan(input interface{}) error { - return nil -} - -type ScannerDataTypeTestStruct struct { - Field1 int - ScannerDataType *ScannerDataType `sql:"TYPE:json"` -} - -type ScannerDataType2 struct { - Street string `sql:"TYPE:varchar(24)"` -} - -func (ScannerDataType2) Value() (driver.Value, error) { - return nil, nil -} - -func (*ScannerDataType2) Scan(input interface{}) error { - return nil -} - -type ScannerDataTypeTestStruct2 struct { - Field1 int - ScannerDataType *ScannerDataType2 -} - -func TestScannerDataType(t *testing.T) { - scope := gorm.Scope{Value: &ScannerDataTypeTestStruct{}} - if field, ok := scope.FieldByName("ScannerDataType"); ok { - if DB.Dialect().DataTypeOf(field.StructField) != "json" { - t.Errorf("data type for scanner is wrong") - } - } - - scope = gorm.Scope{Value: &ScannerDataTypeTestStruct2{}} - if field, ok := scope.FieldByName("ScannerDataType"); ok { - if DB.Dialect().DataTypeOf(field.StructField) != "varchar(24)" { - t.Errorf("data type for scanner is wrong") - } - } -} diff --git a/scope.go b/scope.go deleted file mode 100644 index d82cadbc..00000000 --- a/scope.go +++ /dev/null @@ -1,1421 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" - "regexp" - "strings" - "time" -) - -// Scope contain current operation's information when you perform any operation on the database -type Scope struct { - Search *search - Value interface{} - SQL string - SQLVars []interface{} - db *DB - instanceID string - primaryKeyField *Field - skipLeft bool - fields *[]*Field - selectAttrs *[]string -} - -// IndirectValue return scope's reflect value's indirect value -func (scope *Scope) IndirectValue() reflect.Value { - return indirect(reflect.ValueOf(scope.Value)) -} - -// New create a new Scope without search information -func (scope *Scope) New(value interface{}) *Scope { - return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} -} - -//////////////////////////////////////////////////////////////////////////////// -// Scope DB -//////////////////////////////////////////////////////////////////////////////// - -// DB return scope's DB connection -func (scope *Scope) DB() *DB { - return scope.db -} - -// NewDB create a new DB without search information -func (scope *Scope) NewDB() *DB { - if scope.db != nil { - db := scope.db.clone() - db.search = nil - db.Value = nil - return db - } - return nil -} - -// SQLDB return *sql.DB -func (scope *Scope) SQLDB() SQLCommon { - return scope.db.db -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.dialect -} - -// Quote used to quote string to escape them for database -func (scope *Scope) Quote(str string) string { - if strings.Contains(str, ".") { - newStrs := []string{} - for _, str := range strings.Split(str, ".") { - newStrs = append(newStrs, scope.Dialect().Quote(str)) - } - return strings.Join(newStrs, ".") - } - - return scope.Dialect().Quote(str) -} - -// Err add error to Scope -func (scope *Scope) Err(err error) error { - if err != nil { - scope.db.AddError(err) - } - return err -} - -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil -} - -// Log print log message -func (scope *Scope) Log(v ...interface{}) { - scope.db.log(v...) -} - -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - if scope.fields == nil { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - scope.fields = &fields - } - - return *scope.fields -} - -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToColumnName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - -// PrimaryFields return scope's primary fields -func (scope *Scope) PrimaryFields() (fields []*Field) { - for _, field := range scope.Fields() { - if field.IsPrimaryKey { - fields = append(fields, field) - } - } - return fields -} - -// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one -func (scope *Scope) PrimaryField() *Field { - if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - if len(primaryFields) > 1 { - if field, ok := scope.FieldByName("id"); ok { - return field - } - } - return scope.PrimaryFields()[0] - } - return nil -} - -// PrimaryKey get main primary field's db name -func (scope *Scope) PrimaryKey() string { - if field := scope.PrimaryField(); field != nil { - return field.DBName - } - return "" -} - -// PrimaryKeyZero check main primary field's value is blank or not -func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryField() - return field == nil || field.IsBlank -} - -// PrimaryKeyValue get the primary key's value -func (scope *Scope) PrimaryKeyValue() interface{} { - if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { - return field.Field.Interface() - } - return 0 -} - -// HasColumn to check if has column -func (scope *Scope) HasColumn(column string) bool { - for _, field := range scope.GetStructFields() { - if field.IsNormal && (field.Name == column || field.DBName == column) { - return true - } - } - return false -} - -// SetColumn to set the column's value, column could be field or field's name/dbname -func (scope *Scope) SetColumn(column interface{}, value interface{}) error { - var updateAttrs = map[string]interface{}{} - if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - updateAttrs = attrs.(map[string]interface{}) - defer scope.InstanceSet("gorm:update_attrs", updateAttrs) - } - - if field, ok := column.(*Field); ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - _, skipBindVar := scope.InstanceGet("skip_bindvar") - - if expr, ok := value.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - if skipBindVar { - scope.AddToVars(arg) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - - if skipBindVar { - return "?" - } - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) - } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Contains(scope.Search.tableName, " ") { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - joinSQL := scope.joinsSQL() - whereSQL := scope.whereSQL() - if scope.Search.raw { - whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") - } - return joinSQL + whereSQL + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); scope.Err(err) == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var ( - columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` - isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") - countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") -) - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - values = make([]interface{}, len(columns)) - selectFields []*Field - selectedColumnsMap = map[string]int{} - resetFields = map[int]*Field{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - offset := 0 - if idx, ok := selectedColumnsMap[column]; ok { - offset = idx + 1 - selectFields = selectFields[offset:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[index] = field - } - - selectedColumnsMap[column] = offset + fieldIndex - - if field.IsNormal { - break - } - } - } - } - - scope.Err(rows.Scan(values...)) - - for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { - var ( - quotedTableName = scope.QuotedTableName() - quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) - equalSQL = "=" - inSQL = "IN" - ) - - // If building not conditions - if !include { - equalSQL = "<>" - inSQL = "NOT IN" - } - - switch value := clause["query"].(type) { - case sql.NullInt64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - if !include && reflect.ValueOf(value).Len() == 0 { - return - } - str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) - clause["args"] = []interface{}{value} - case string: - if isNumberRegexp.MatchString(value) { - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) - } - - if value != "" { - if !include { - if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) - } - } else { - str = fmt.Sprintf("(%v)", value) - } - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) - } else { - if !include { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) - } - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - newScope := scope.New(value) - - if len(newScope.Fields()) == 0 { - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - scopeQuotedTableName := newScope.QuotedTableName() - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - default: - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - - replacements := []string{} - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if as, ok := arg.([][]interface{}); ok { - var tempMarks []string - for _, a := range as { - var arrayMarks []string - for _, v := range a { - arrayMarks = append(arrayMarks, scope.AddToVars(v)) - } - - if len(arrayMarks) > 0 { - tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) - } - } - - if len(tempMarks) > 0 { - replacements = append(replacements, strings.Join(tempMarks, ",")) - } - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = valuer.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) - } - - if err != nil { - scope.Err(err) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for _, s := range str { - if s == '?' && len(replacements) > i { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(s) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - replacements := []string{} - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - replacements = append(replacements, scope.AddToVars(arg)) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for pos, char := range str { - if str[pos] == '?' { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(char) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && hasDeletedAtField { - sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildCondition(clause, false); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { - return "" - } - - var orders []string - for _, order := range scope.Search.orders { - if str, ok := order.(string); ok { - orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - orders = append(orders, exp) - } - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) - scope.Err(err) - return sql -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(scope.CombinedConditionSql()) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) - } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - defer func() { - if err := recover(); err != nil { - if db, ok := scope.db.db.(sqlTx); ok { - db.Rollback() - } - panic(err) - } - }() - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values, db: db}).Fields() { - if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false, scope.db), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value, true, scope.db) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*SqlExpr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal && !field.IsIgnored { - hasUpdate = true - if err == ErrUnaddressable { - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - - result := &RowQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Row -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - - result := &RowsQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Rows, result.Error -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { - queryStr := strings.ToLower(fmt.Sprint(query)) - if queryStr == column { - return true - } - - if strings.HasSuffix(queryStr, "as "+column) { - return true - } - - if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { - return true - } - - return false -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - if dest.Len() > 0 { - dest.Set(reflect.Zero(dest.Type())) - } - - if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { - scope.Search.Select(column) - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - if len(scope.Search.group) != 0 { - if len(scope.Search.havingConditions) != 0 { - scope.prepareQuerySQL() - scope.Search = &search{} - scope.Search.Select("count(*)") - scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL)) - } else { - scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") - scope.Search.group += " ) AS count_table" - } - } else { - scope.Search.Select("count(*)") - } - } - scope.Search.ignoreOrderQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - tx := scope.db.Set("gorm:association:source", scope.Value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(tx.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - scope.Err(tx.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return " " + tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. - if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - // Compatible with old generated key - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var mysql mysql - var query string - if scope.Dialect().GetName() == mysql.GetName() { - query = `ALTER TABLE %s DROP FOREIGN KEY %s;` - } else { - query = `ALTER TABLE %s DROP CONSTRAINT %s;` - } - - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettingsGet("INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - indexes[name] = append(indexes[name], column) - } - } - - if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - uniqueIndexes[name] = append(uniqueIndexes[name], column) - } - } - } - - for name, columns := range indexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - for name, columns := range uniqueIndexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - resultMap := make(map[string][]interface{}) - for _, value := range values { - indirectValue := indirect(reflect.ValueOf(value)) - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - var hasValue = false - for _, column := range columns { - field := object.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) - if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - case reflect.Struct: - var result []interface{} - var hasValue = false - for _, column := range columns { - field := indirectValue.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) - if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - } - for _, v := range resultMap { - results = append(results, v) - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - resultsMap := map[interface{}]bool{} - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { - resultsMap[elem.Addr()] = true - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() && resultsMap[result.Addr()] != true { - resultsMap[result.Addr()] = true - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} - -func (scope *Scope) hasConditions() bool { - return !scope.PrimaryKeyZero() || - len(scope.Search.whereConditions) > 0 || - len(scope.Search.orConditions) > 0 || - len(scope.Search.notConditions) > 0 -} diff --git a/scope_test.go b/scope_test.go deleted file mode 100644 index f7f1ed08..00000000 --- a/scope_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package gorm_test - -import ( - "encoding/hex" - "math/rand" - "strings" - "testing" - - "github.com/jinzhu/gorm" -) - -func NameIn1And2(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) -} - -func NameIn2And3(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) -} - -func NameIn(names []string) func(d *gorm.DB) *gorm.DB { - return func(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", names) - } -} - -func TestScopes(t *testing.T) { - user1 := User{Name: "ScopeUser1", Age: 1} - user2 := User{Name: "ScopeUser2", Age: 1} - user3 := User{Name: "ScopeUser3", Age: 2} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users1, users2, users3 []User - DB.Scopes(NameIn1And2).Find(&users1) - if len(users1) != 2 { - t.Errorf("Should found two users's name in 1, 2") - } - - DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) - if len(users2) != 1 { - t.Errorf("Should found one user's name is 2") - } - - DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) - if len(users3) != 2 { - t.Errorf("Should found two users's name in 1, 3") - } -} - -func randName() string { - data := make([]byte, 8) - rand.Read(data) - - return "n-" + hex.EncodeToString(data) -} - -func TestValuer(t *testing.T) { - name := randName() - - origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} - if err := DB.Save(&origUser).Error; err != nil { - t.Errorf("No error should happen when saving user, but got %v", err) - } - - var user2 User - if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil { - t.Errorf("No error should happen when querying user with valuer, but got %v", err) - } -} - -func TestFailedValuer(t *testing.T) { - name := randName() - - err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error - - if err == nil { - t.Errorf("There should be an error should happen when insert data") - } else if !strings.HasPrefix(err.Error(), "Should not start with") { - t.Errorf("The error should be returned from Valuer, but get %v", err) - } -} - -func TestDropTableWithTableOptions(t *testing.T) { - type UserWithOptions struct { - gorm.Model - } - DB.AutoMigrate(&UserWithOptions{}) - - DB = DB.Set("gorm:table_options", "CHARSET=utf8") - err := DB.DropTable(&UserWithOptions{}).Error - if err != nil { - t.Errorf("Table must be dropped, got error %s", err) - } -} diff --git a/search.go b/search.go deleted file mode 100644 index 7c4cc184..00000000 --- a/search.go +++ /dev/null @@ -1,153 +0,0 @@ -package gorm - -import ( - "fmt" -) - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []interface{} - preload []searchPreload - offset interface{} - limit interface{} - group string - tableName string - raw bool - Unscoped bool - ignoreOrderQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := *s - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value interface{}, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - s.orders = []interface{}{} - } - - if value != nil && value != "" { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit interface{}) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset interface{}) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*SqlExpr); ok { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) - } else { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - } - return s -} - -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/search_test.go b/search_test.go deleted file mode 100644 index 4db7ab6a..00000000 --- a/search_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import ( - "reflect" - "testing" -) - -func TestCloneSearch(t *testing.T) { - s := new(search) - s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") - - s1 := s.clone() - s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") - - if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { - t.Errorf("Where should be copied") - } - - if reflect.DeepEqual(s.orders, s1.orders) { - t.Errorf("Order should be copied") - } - - if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { - t.Errorf("InitAttrs should be copied") - } - - if reflect.DeepEqual(s.Select, s1.Select) { - t.Errorf("selectStr should be copied") - } -} diff --git a/test_all.sh b/test_all.sh deleted file mode 100755 index 5cfb3321..00000000 --- a/test_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -dialects=("postgres" "mysql" "mssql" "sqlite") - -for dialect in "${dialects[@]}" ; do - DEBUG=false GORM_DIALECT=${dialect} go test -done diff --git a/update_test.go b/update_test.go deleted file mode 100644 index 85d53e5f..00000000 --- a/update_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -func TestUpdate(t *testing.T) { - product1 := Product{Code: "product1code"} - product2 := Product{Code: "product2code"} - - DB.Save(&product1).Save(&product2).Update("code", "product2newcode") - - if product2.Code != "product2newcode" { - t.Errorf("Record should be updated") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt1 := product1.UpdatedAt - - if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { - t.Errorf("Product1 should not be updated") - } - - if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") - - var product4 Product - DB.First(&product4, product1.Id) - if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { - t.Errorf("Product1's code should be updated") - } - - if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { - t.Errorf("Product should not be changed to 789") - } - - if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update with CamelCase") - } - - if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update_column with CamelCase") - } - - var products []Product - DB.Find(&products) - if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { - t.Error("RowsAffected should be correct when do batch update") - } - - DB.First(&product4, product4.Id) - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("Update with expression") - } - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Update with expression should update UpdatedAt") - } -} - -func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - DB.Save(&animal) - updatedAt1 := animal.UpdatedAt - - DB.Save(&animal).Update("name", "Francis") - - if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") - } - - var animals []Animal - DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { - t.Error("RowsAffected should be correct when do batch update") - } - - animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) - DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched - DB.First(&animal, animal.Counter) - if animal.Name != "galeone" { - t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) - } - - // When changing a field with a default value, the change must occur - animal.Name = "amazing horse" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "amazing horse" { - t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) - } - - // When changing a field with a default value with blank value - animal.Name = "" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "" { - t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) - } -} - -func TestUpdates(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 10} - DB.Save(&product1).Save(&product2) - DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) - if product1.Code != "product1newcode" || product1.Price != 100 { - t.Errorf("Record should be updated also with map") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - - if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { - t.Errorf("Product2 should not be updated") - } - - if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { - t.Errorf("Product1 should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) - if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("product2's code should be updated") - } - - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100 { - t.Errorf("Updates with expression") - } - // product4's UpdatedAt will be reset when updating - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Updates with expression should update UpdatedAt") - } -} - -func TestUpdateColumn(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 20} - DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) - if product2.Code != "product2newcode" || product2.Price != 100 { - t.Errorf("product 2 should be updated with update column") - } - - var product3 Product - DB.First(&product3, product1.Id) - if product3.Code != "product1code" || product3.Price != 10 { - t.Errorf("product 1 should not be updated") - } - - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - DB.Model(product2).UpdateColumn("code", "update_column_new") - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated with update column") - } - - DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("UpdateColumn with expression") - } - if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateColumn with expression should not update UpdatedAt") - } -} - -func TestSelectWithUpdate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestSelectWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestOmitWithUpdate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships that not omitted") - } -} - -func TestOmitWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships not omitted") - } -} - -func TestSelectWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } -} - -func TestOmitWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should omit name column when update user") - } -} - -func TestUpdateColumnsSkipsAssociations(t *testing.T) { - user := getPreparedUser("update_columns_user", "special_role") - user.Age = 99 - address1 := "first street" - user.BillingAddress = Address{Address1: address1} - DB.Save(user) - - // Update a single field of the user and verify that the changed address is not stored. - newAge := int64(100) - user.BillingAddress.Address1 = "second street" - db := DB.Model(user).UpdateColumns(User{Age: newAge}) - if db.RowsAffected != 1 { - t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) - } - - // Verify that Age now=`newAge`. - freshUser := &User{Id: user.Id} - DB.First(freshUser) - if freshUser.Age != newAge { - t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) - } - - // Verify that user's BillingAddress.Address1 is not changed and is still "first street". - DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) - if freshUser.BillingAddress.Address1 != address1 { - t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) - } -} - -func TestUpdatesWithBlankValues(t *testing.T) { - product := Product{Code: "product1", Price: 10} - DB.Save(&product) - - DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) - - var product1 Product - DB.First(&product1, product.Id) - - if product1.Code != "product1" || product1.Price != 100 { - t.Errorf("product's code should not be updated") - } -} - -type ElementWithIgnoredField struct { - Id int64 - Value string - IgnoredField int64 `sql:"-"` -} - -func (e ElementWithIgnoredField) TableName() string { - return "element_with_ignored_field" -} - -func TestUpdatesTableWithIgnoredValues(t *testing.T) { - elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} - DB.Save(&elem) - - DB.Table(elem.TableName()). - Where("id = ?", elem.Id). - // DB.Model(&ElementWithIgnoredField{Id: elem.Id}). - Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) - - var elem1 ElementWithIgnoredField - err := DB.First(&elem1, elem.Id).Error - if err != nil { - t.Errorf("error getting an element from database: %s", err.Error()) - } - - if elem1.IgnoredField != 0 { - t.Errorf("element's ignored field should not be updated") - } -} - -func TestUpdateDecodeVirtualAttributes(t *testing.T) { - var user = User{ - Name: "jinzhu", - IgnoreMe: 88, - } - - DB.Save(&user) - - DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) - - if user.IgnoreMe != 100 { - t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") - } -} diff --git a/utils.go b/utils.go deleted file mode 100644 index d2ae9465..00000000 --- a/utils.go +++ /dev/null @@ -1,226 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -// SQL expression -type SqlExpr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *SqlExpr { - return &SqlExpr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - switch value.Kind() { - case reflect.String: - return value.Len() == 0 - case reflect.Bool: - return !value.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return value.Uint() == 0 - case reflect.Float32, reflect.Float64: - return value.Float() == 0 - case reflect.Interface, reflect.Ptr: - return value.IsNil() - } - - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/wercker.yml b/wercker.yml deleted file mode 100644 index c74fa4d4..00000000 --- a/wercker.yml +++ /dev/null @@ -1,154 +0,0 @@ -# use the default golang container from Docker Hub -box: golang - -services: - - name: mariadb - id: mariadb:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql - id: mysql:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql57 - id: mysql:5.7 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql56 - id: mysql:5.6 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: postgres - id: postgres:latest - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres96 - id: postgres:9.6 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres95 - id: postgres:9.5 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres94 - id: postgres:9.4 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres93 - id: postgres:9.3 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: mssql - id: mcmoe/mssqldocker:latest - env: - ACCEPT_EULA: Y - SA_PASSWORD: LoremIpsum86 - MSSQL_DB: gorm - MSSQL_USER: gorm - MSSQL_PASSWORD: LoremIpsum86 - -# The steps that will be executed in the build pipeline -build: - # The steps that will be executed on build - steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace - - # Gets the dependencies - - script: - name: go get - code: | - cd $WERCKER_SOURCE_DIR - go version - go get -t -v ./... - - # Build the project - - script: - name: go build - code: | - go build ./... - - # Test the project - - script: - name: test sqlite - code: | - go test -race -v ./... - - - script: - name: test mariadb - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.7 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.6 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test postgres - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres96 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres95 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres94 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres93 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test mssql - code: | - GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./... - - - script: - name: codecov - code: | - go test -race -coverprofile=coverage.txt -covermode=atomic ./... - bash <(curl -s https://codecov.io/bash)