Add FirstOrInit support
This commit is contained in:
		
							parent
							
								
									5be642a435
								
							
						
					
					
						commit
						ac8708b500
					
				| @ -224,13 +224,15 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Assign(attrs ...interface{}) (tx *DB) { | ||||
| func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.attrs = attrs | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { | ||||
| func (db *DB) Assign(attrs ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.assigns = attrs | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -171,14 +171,3 @@ func (like Like) NegationBuild(builder Builder) { | ||||
| 	builder.WriteString(" NOT LIKE ") | ||||
| 	builder.AddVar(builder, like.Value) | ||||
| } | ||||
| 
 | ||||
| // Map
 | ||||
| type Map map[interface{}]interface{} | ||||
| 
 | ||||
| func (m Map) Build(builder Builder) { | ||||
| 	// TODO
 | ||||
| } | ||||
| 
 | ||||
| func (m Map) NegationBuild(builder Builder) { | ||||
| 	// TODO
 | ||||
| } | ||||
|  | ||||
| @ -2,6 +2,7 @@ package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 
 | ||||
| @ -99,13 +100,56 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) FirstOrInit(dest interface{}, where ...interface{}) (tx *DB) { | ||||
| func (tx *DB) assignExprsToValue(exprs []clause.Expression) { | ||||
| 	for _, expr := range exprs { | ||||
| 		if eq, ok := expr.(clause.Eq); ok { | ||||
| 			switch column := eq.Column.(type) { | ||||
| 			case string: | ||||
| 				if field := tx.Statement.Schema.LookUpField(column); field != nil { | ||||
| 					field.Set(tx.Statement.ReflectValue, eq.Value) | ||||
| 				} | ||||
| 			case clause.Column: | ||||
| 				if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { | ||||
| 					field.Set(tx.Statement.ReflectValue, eq.Value) | ||||
| 				} | ||||
| 			default: | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { | ||||
| 		if c, ok := tx.Statement.Clauses["WHERE"]; ok { | ||||
| 			if where, ok := c.Expression.(clause.Where); ok { | ||||
| 				tx.assignExprsToValue(where.Exprs) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// initialize with attrs, conds
 | ||||
| 		if len(tx.Statement.attrs) > 0 { | ||||
| 			exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]) | ||||
| 			tx.assignExprsToValue(exprs) | ||||
| 		} | ||||
| 		tx.Error = nil | ||||
| 	} | ||||
| 
 | ||||
| 	// initialize with attrs, conds
 | ||||
| 	if len(tx.Statement.assigns) > 0 { | ||||
| 		exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) | ||||
| 		tx.assignExprsToValue(exprs) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	// if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) {
 | ||||
| 	// 	// initialize with attrs, conds
 | ||||
| 	// }
 | ||||
| 
 | ||||
| 	// assign dest
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										49
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								statement.go
									
									
									
									
									
								
							| @ -34,6 +34,8 @@ type Statement struct { | ||||
| 	SQL                  strings.Builder | ||||
| 	Vars                 []interface{} | ||||
| 	NamedVars            []sql.NamedArg | ||||
| 	attrs                []interface{} | ||||
| 	assigns              []interface{} | ||||
| } | ||||
| 
 | ||||
| // StatementModifier statement modifier interface
 | ||||
| @ -195,7 +197,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { | ||||
| } | ||||
| 
 | ||||
| // BuildCondtion build condition
 | ||||
| func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { | ||||
| func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) { | ||||
| 	if sql, ok := query.(string); ok { | ||||
| 		if i, err := strconv.Atoi(sql); err == nil { | ||||
| 			query = i | ||||
| @ -212,42 +214,53 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con | ||||
| 
 | ||||
| 		switch v := arg.(type) { | ||||
| 		case clause.Expression: | ||||
| 			conditions = append(conditions, v) | ||||
| 			conds = append(conds, v) | ||||
| 		case *DB: | ||||
| 			if v.Statement == nil { | ||||
| 				if cs, ok := v.Statement.Clauses["WHERE"]; ok { | ||||
| 					conditions = append(conditions, cs.Expression) | ||||
| 					conds = append(conds, cs.Expression) | ||||
| 				} | ||||
| 			} | ||||
| 		case map[interface{}]interface{}: | ||||
| 			var clauseMap = clause.Map{} | ||||
| 			for i, j := range v { | ||||
| 				clauseMap[i] = j | ||||
| 				conds = append(conds, clause.Eq{Column: i, Value: j}) | ||||
| 			} | ||||
| 			conditions = append(conditions, clauseMap) | ||||
| 		case map[string]string: | ||||
| 			var clauseMap = clause.Map{} | ||||
| 			for i, j := range v { | ||||
| 				clauseMap[i] = j | ||||
| 				conds = append(conds, clause.Eq{Column: i, Value: j}) | ||||
| 			} | ||||
| 			conditions = append(conditions, clauseMap) | ||||
| 		case map[string]interface{}: | ||||
| 			var clauseMap = clause.Map{} | ||||
| 			for i, j := range v { | ||||
| 				clauseMap[i] = j | ||||
| 				conds = append(conds, clause.Eq{Column: i, Value: j}) | ||||
| 			} | ||||
| 			conditions = append(conditions, clauseMap) | ||||
| 		default: | ||||
| 			// TODO check is struct
 | ||||
| 			// struct, slice -> ids
 | ||||
| 			reflectValue := reflect.Indirect(reflect.ValueOf(arg)) | ||||
| 			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { | ||||
| 				switch reflectValue.Kind() { | ||||
| 				case reflect.Struct: | ||||
| 					for _, field := range s.FieldsByDBName { | ||||
| 						if v, isZero := field.ValueOf(reflectValue); !isZero { | ||||
| 							conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) | ||||
| 						} | ||||
| 					} | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for i := 0; i < reflectValue.Len(); i++ { | ||||
| 						for _, field := range s.FieldsByDBName { | ||||
| 							if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { | ||||
| 								conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if len(conditions) == 0 { | ||||
| 		conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args}) | ||||
| 	if len(conds) == 0 { | ||||
| 		conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) | ||||
| 	} | ||||
| 
 | ||||
| 	return conditions | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Build build sql with clauses names
 | ||||
| @ -337,7 +350,7 @@ func (stmt *Statement) reinit() { | ||||
| 	// 	return true
 | ||||
| 	// })
 | ||||
| 
 | ||||
| 	stmt.Schema = nil | ||||
| 	// stmt.Schema = nil
 | ||||
| 	stmt.SQL.Reset() | ||||
| 	stmt.Vars = nil | ||||
| 	stmt.NamedVars = nil | ||||
|  | ||||
							
								
								
									
										56
									
								
								tests/upsert_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								tests/upsert_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,56 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestFindOrInitialize(t *testing.T) { | ||||
| 	var user1, user2, user3, user4, user5, user6 User | ||||
| 	if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { | ||||
| 		t.Errorf("no error should happen when FirstOrInit, but got %v", err) | ||||
| 	} | ||||
| 	if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { | ||||
| 		t.Errorf("user should be initialized with search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) | ||||
| 	if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 { | ||||
| 		t.Errorf("user should be initialized with search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) | ||||
| 	if user3.Name != "find or init 2" || user3.ID != 0 { | ||||
| 		t.Errorf("user should be initialized with inline search value") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) | ||||
| 	if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { | ||||
| 		t.Errorf("user should be initialized with search value and attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) | ||||
| 	if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { | ||||
| 		t.Errorf("user should be initialized with search value and assign attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Save(&User{Name: "find or init", Age: 33}) | ||||
| 	DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) | ||||
| 	if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 { | ||||
| 		t.Errorf("user should be found and not initialized by Attrs") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) | ||||
| 	if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 { | ||||
| 		t.Errorf("user should be found with FirstOrInit") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) | ||||
| 	if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 { | ||||
| 		t.Errorf("user should be found and updated with assigned attrs") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFindOrCreate(t *testing.T) { | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu