diff --git a/callbacks/create.go b/callbacks/create.go index 0fe1dc93..f0b78139 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -302,7 +303,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || + strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { if field.AutoUpdateTime > 0 { assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} switch field.AutoUpdateTime { diff --git a/callbacks/query.go b/callbacks/query.go index 04b5c657..95db1f0a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -8,6 +8,8 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func Query(db *gorm.DB) { @@ -109,86 +111,136 @@ func BuildQuerySQL(db *gorm.DB) { } } + specifiedRelationsName := make(map[string]interface{}) for _, join := range db.Statement.Joins { - if db.Statement.Schema == nil { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, - }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { - tableAliasName := relation.Name - - columnStmt := gorm.Statement{ - Table: tableAliasName, DB: db, Schema: relation.FieldSchema, - Selects: join.Selects, Omits: join.Omits, - } - - selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) - for _, s := range relation.FieldSchema.DBNames { - if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) - } - } - - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } } - } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - } + + if isNestedJoin { + isRelations = true + relations = gussNestedRelations } } } - { - onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} - for _, c := range relation.FieldSchema.QueryClauses { - onStmt.AddClause(c) - } + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } - if join.On != nil { - onStmt.AddClause(join.On) - } + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } - if cs, ok := onStmt.Clauses["WHERE"]; ok { - if where, ok := cs.Expression.(clause.Where); ok { - where.Build(&onStmt) + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), + }) + } + } - if onSQL := onStmt.SQL.String(); onSQL != "" { - vars := onStmt.Vars - for idx, v := range vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } } - - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) } } - } - } - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Type: join.JoinType, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } + + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } + } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } + } + + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + parentTableName = rel.Name + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, diff --git a/callbacks/update.go b/callbacks/update.go index fe6f0994..4eb75788 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -245,11 +245,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } default: updatingSchema := stmt.Schema + var isDiffSchema bool if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { // different schema updatingStmt := &gorm.Statement{DB: stmt.DB} if err := updatingStmt.Parse(stmt.Dest); err == nil { updatingSchema = updatingStmt.Schema + isDiffSchema = true } } @@ -276,7 +278,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if (ok || !isZero) && field.Updatable { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + assignField := field + if isDiffSchema { + if originField := stmt.Schema.LookUpField(dbName); originField != nil { + assignField = originField + } + } + assignValue(assignField, value) } } } else { diff --git a/finisher_api.go b/finisher_api.go index f16d4f43..e6fe4666 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -6,6 +6,8 @@ import ( "fmt" "reflect" "strings" + "sync" + "sync/atomic" "gorm.io/gorm/clause" "gorm.io/gorm/logger" @@ -608,6 +610,15 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } +var ( + savepointIdx int64 + savepointNamePool = &sync.Pool{ + New: func() interface{} { + return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1)) + }, + } +) + // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs // they are rolled back. @@ -617,7 +628,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { - err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + poolName := savepointNamePool.Get() + defer savepointNamePool.Put(poolName) + err = db.SavePoint(poolName.(string)).Error if err != nil { return } @@ -625,7 +638,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { - db.RollbackTo(fmt.Sprintf("sp%p", fc)) + db.RollbackTo(poolName.(string)) } }() } diff --git a/gorm.go b/gorm.go index b5d98196..9a70c3d2 100644 --- a/gorm.go +++ b/gorm.go @@ -347,14 +347,16 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { - if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { - err = errTranslator.Translate(err) - } + if err != nil { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } - if db.Error == nil { - db.Error = err - } else if err != nil { - db.Error = fmt.Errorf("%v; %w", db.Error, err) + if db.Error == nil { + db.Error = err + } else { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } } return db.Error } diff --git a/scan.go b/scan.go index 12a77862..736db4d3 100644 --- a/scan.go +++ b/scan.go @@ -4,10 +4,10 @@ import ( "database/sql" "database/sql/driver" "reflect" - "strings" "time" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // prepareValues prepare values slice @@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() @@ -65,28 +65,45 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}) + joinedNestedSchemaMap := make(map[string]interface{}) for idx, field := range fields { if field == nil { continue } - if len(joinFields) == 0 || joinFields[idx][0] == nil { + if len(joinFields) == 0 || len(joinFields[idx]) == 0 { db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - joinSchema := joinFields[idx][0] - relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr { - if _, ok := joinedSchemaMap[joinSchema]; !ok { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } + } else { // joinFields count is larger than 2 when using join + var isNilPtrValue bool + var relValue reflect.Value + // does not contain raw dbname + nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1] + // current reflect value + currentReflectValue := reflectValue + fullRels := make([]string, 0, len(nestedJoinSchemas)) + for _, joinSchema := range nestedJoinSchemas { + fullRels = append(fullRels, joinSchema.Name) + relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue) + if relValue.Kind() == reflect.Ptr { + fullRelsName := utils.JoinNestedRelationNames(fullRels) + // same nested structure + if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + isNilPtrValue = true + break + } - relValue.Set(reflect.New(relValue.Type().Elem())) - joinedSchemaMap[joinSchema] = nil + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedNestedSchemaMap[fullRelsName] = nil + } } + currentReflectValue = relValue + } + + if !isNilPtrValue { // ignore if value is nil + f := joinFields[idx][len(joinFields[idx])-1] + db.AddError(f.Set(db.Statement.Context, relValue, values[idx])) } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } // release data to pool @@ -163,7 +180,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { default: var ( fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field + joinFields [][]*schema.Field sch = db.Statement.Schema reflectValue = db.Statement.ReflectValue ) @@ -217,15 +234,26 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } else { matchedFieldCount[column] = 1 } - } else if names := strings.Split(column, "__"); len(names) > 1 { + } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + subNameCount := len(names) + // nested relation fields + relFields := make([]*schema.Field, 0, subNameCount-1) + relFields = append(relFields, rel.Field) + for _, name := range names[1 : subNameCount-1] { + rel = rel.FieldSchema.Relationships.Relations[name] + relFields = append(relFields, rel.Field) + } + // lastest name is raw dbname + dbName := names[subNameCount-1] + if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable { fields[idx] = field if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) + joinFields = make([][]*schema.Field, len(columns)) } - joinFields[idx] = [2]*schema.Field{rel.Field, field} + relFields = append(relFields, field) + joinFields[idx] = relFields continue } } diff --git a/schema/field.go b/schema/field.go index d93ec538..921f3c44 100644 --- a/schema/field.go +++ b/schema/field.go @@ -920,6 +920,8 @@ func (field *Field) setupValuerAndSetter() { sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() } + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { if s, ok := v.(*serializer); ok { if s.fieldValue != nil { @@ -927,11 +929,12 @@ func (field *Field) setupValuerAndSetter() { } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if sameElemType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } else if sameType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + s.Serializer = si.Interface().(SerializerInterface) } } else { err = oldFieldSetter(ctx, value, v) @@ -943,11 +946,15 @@ func (field *Field) setupValuerAndSetter() { func (field *Field) setupNewValuePool() { if field.Serializer != nil { + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.NewValuePool = &sync.Pool{ New: func() interface{} { + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) return &serializer{ Field: field, - Serializer: field.Serializer, + Serializer: si.Interface().(SerializerInterface), } }, } diff --git a/schema/schema.go b/schema/schema.go index 6e5cb4a9..e13a5ed1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -250,8 +250,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { - schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + if schema.PrioritizedPrimaryField == nil { + if len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } else if len(schema.PrimaryFields) > 1 { + // If there are multiple primary keys, the AUTOINCREMENT field is prioritized + for _, field := range schema.PrimaryFields { + if field.AutoIncrement { + schema.PrioritizedPrimaryField = field + break + } + } + } } for _, field := range schema.PrimaryFields { diff --git a/schema/schema_test.go b/schema/schema_test.go index 8a752fb7..5bc0fb83 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -293,3 +293,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { }) } } + +func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) { + type Product struct { + ProductID uint `gorm:"primaryKey;autoIncrement"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + type ProductNonAutoIncrement struct { + ProductID uint `gorm:"primaryKey;autoIncrement:false"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + + product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + prioritizedPrimaryField := schema.Field{ + Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"}, + } + + product.Fields = []*schema.Field{product.PrioritizedPrimaryField} + + checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + + productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err) + } + + if productNonAutoIncrement.PrioritizedPrimaryField != nil { + t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") + } +} diff --git a/tests/create_test.go b/tests/create_test.go index 274a7f48..75aa8cba 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -547,3 +547,68 @@ func TestFirstOrCreateRowsAffected(t *testing.T) { t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) } } + +func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { + type CompositeKeyProduct struct { + ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key + LanguageCode int `gorm:"primaryKey;"` // primary key + Code string + Name string + } + + if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + prod := &CompositeKeyProduct{ + LanguageCode: 56, + Code: "Code56", + Name: "ProductName56", + } + if err := DB.Create(&prod).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + newProd := &CompositeKeyProduct{} + if err := DB.First(&newProd).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name") + } +} + +func TestCreateOnConfilctWithDefalutNull(t *testing.T) { + type OnConfilctUser struct { + ID string + Name string `gorm:"default:null"` + Email string + Mobile string `gorm:"default:'133xxxx'"` + } + + err := DB.Migrator().DropTable(&OnConfilctUser{}) + AssertEqual(t, err, nil) + err = DB.AutoMigrate(&OnConfilctUser{}) + AssertEqual(t, err, nil) + + u := OnConfilctUser{ + ID: "on-confilct-user-id", + Name: "on-confilct-user-name", + Email: "on-confilct-user-email", + Mobile: "on-confilct-user-mobile", + } + err = DB.Create(&u).Error + AssertEqual(t, err, nil) + + u.Name = "on-confilct-user-name-2" + u.Email = "on-confilct-user-email-2" + u.Mobile = "" + err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error + AssertEqual(t, err, nil) + + var u2 OnConfilctUser + err = DB.Where("id = ?", u.ID).First(&u2).Error + AssertEqual(t, err, nil) + AssertEqual(t, u2.Name, "on-confilct-user-name-2") + AssertEqual(t, u2.Email, "on-confilct-user-email-2") + AssertEqual(t, u2.Mobile, "133xxxx") +} diff --git a/tests/go.mod b/tests/go.mod index b2d5ca97..306a530e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,12 @@ require ( github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect github.com/microsoft/go-mssqldb v0.20.0 // indirect + golang.org/x/crypto v0.7.0 // indirect gorm.io/driver/mysql v1.4.7 - gorm.io/driver/postgres v1.4.8 + gorm.io/driver/postgres v1.5.0 gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.2 - gorm.io/gorm v1.24.5 + gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11 ) replace gorm.io/gorm => ../ diff --git a/tests/joins_test.go b/tests/joins_test.go index 057ad333..e6715bbe 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -325,3 +325,66 @@ func TestJoinArgsWithDB(t *testing.T) { } AssertEqual(t, user4.NamedPet.Name, "") } + +func TestNestedJoins(t *testing.T) { + users := []User{ + { + Name: "nested-joins-1", + Manager: GetUser("nested-joins-manager-1", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}}, + }, + { + Name: "nested-joins-2", + Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}}, + }, + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB. + Joins("Manager"). + Joins("Manager.Company"). + Joins("Manager.NamedPet"). + Joins("NamedPet"). + Joins("NamedPet.Toy"). + Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + // user + CheckUser(t, user, users2[idx]) + if users2[idx].Manager == nil { + t.Fatalf("Failed to load Manager") + } + // manager + CheckUser(t, *user.Manager, *users2[idx].Manager) + // user pet + if users2[idx].NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) + // manager pet + if users2[idx].Manager.NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 11a0afda..69f86412 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1542,3 +1542,59 @@ func TestMigrateView(t *testing.T) { t.Fatalf("Failed to drop view, got %v", err) } } + +func TestMigrateExistingBoolColumnPG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type ColumnStruct struct { + gorm.Model + Name string + StringBool string + SmallintBool int `gorm:"type:smallint"` + } + + type ColumnStruct2 struct { + gorm.Model + Name string + StringBool bool // change existing boolean column from string to boolean + SmallintBool bool // change existing boolean column from smallint or other to boolean + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "string_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + case "smallint_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + } + } + } +} diff --git a/tests/update_test.go b/tests/update_test.go index d7634580..b2da11c6 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -773,3 +773,16 @@ func TestUpdateReturning(t *testing.T) { t.Errorf("failed to return updated age column") } } + +func TestUpdateWithDiffSchema(t *testing.T) { + user := GetUser("update-diff-schema-1", Config{}) + DB.Create(&user) + + type UserTemp struct { + Name string + } + + err := DB.Model(&user).Updates(&UserTemp{Name: "update-diff-schema-2"}).Error + AssertEqual(t, err, nil) + AssertEqual(t, "update-diff-schema-2", user.Name) +} diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 661d727f..49d01f2e 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -13,8 +13,14 @@ import ( func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { - got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() - expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + rv := reflect.Indirect(reflect.ValueOf(r)) + ev := reflect.Indirect(reflect.ValueOf(e)) + if rv.IsValid() != ev.IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e) + return + } + got := rv.FieldByName(name).Interface() + expect := ev.FieldByName(name).Interface() t.Run(name, func(t *testing.T) { AssertEqual(t, got, expect) }) diff --git a/utils/utils.go b/utils/utils.go index e08533cd..ddbca60a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -131,3 +131,20 @@ func ToString(value interface{}) string { } return "" } + +const nestedRelationSplit = "__" + +// NestedRelationName nested relationships like `Manager__Company` +func NestedRelationName(prefix, name string) string { + return prefix + nestedRelationSplit + name +} + +// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` +func SplitNestedRelationName(name string) []string { + return strings.Split(name, nestedRelationSplit) +} + +// JoinNestedRelationNames nested relationships like `Manager__Company` +func JoinNestedRelationNames(relationNames []string) string { + return strings.Join(relationNames, nestedRelationSplit) +}