diff --git a/finisher_api.go b/finisher_api.go index 2590e422..c47e12af 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -129,7 +129,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]) + exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) tx.assignExprsToValue(exprs) } tx.Error = nil @@ -137,19 +137,54 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) tx.assignExprsToValue(exprs) } return } -func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() - // if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { - // // initialize with attrs, conds - // } + if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { + tx.Error = nil + + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignExprsToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) + tx.assignExprsToValue(exprs) + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + tx.assignExprsToValue(exprs) + } + + return tx.Create(dest) + } else if len(tx.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + default: + } + } + } + + return tx.Model(dest).Updates(assigns) + } - // assign dest return } @@ -307,3 +342,7 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx.callbacks.Raw().Execute(tx) return } + +func (db *DB) RecordNotFound() bool { + return errors.Is(db.Error, ErrRecordNotFound) +} diff --git a/statement.go b/statement.go index 51dea6fc..b110ac1b 100644 --- a/statement.go +++ b/statement.go @@ -203,6 +203,8 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con query = i } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} + } else if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} } } @@ -238,16 +240,24 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { switch reflectValue.Kind() { case reflect.Struct: - for _, field := range s.FieldsByDBName { + for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue); !isZero { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } } } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - for _, field := range s.FieldsByDBName { + for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } } } } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 728550d5..bd540620 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" . "github.com/jinzhu/gorm/tests" ) @@ -53,4 +54,59 @@ func TestFindOrInitialize(t *testing.T) { } func TestFindOrCreate(t *testing.T) { + var user1, user2, user3, user4, user5, user6, user7, user8 User + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) + if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) + if user1.ID != user2.ID || user2.Name != "find or create" || user2.ID == 0 || user2.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) + if user3.Name != "find or create 2" || user3.ID == 0 { + t.Errorf("user should be created with inline search value") + } + + DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) + if user4.Name != "find or create 3" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and attrs") + } + + updatedAt1 := user4.UpdatedAt + DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) + if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdateAt should be changed when update values with assign") + } + + DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) + if user4.Name != "find or create 4" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or create" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) + if user6.Name != "find or create" || user6.ID == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Find(&user7) + if user7.Name != "find or create" || user7.ID == 0 || user7.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, Account: Account{Number: "1231231231"}, Pets: []*Pet{{Name: "first_or_create_pet1"}, {Name: "first_or_create_pet2"}}}).FirstOrCreate(&user8) + if DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).RecordNotFound() { + t.Errorf("has many association should be saved") + } + + if DB.Where("number = ?", "1231231231").First(&Account{}).RecordNotFound() { + t.Errorf("belongs to association should be saved") + } }