Support save slice of data
This commit is contained in:
		
							parent
							
								
									22ff8377df
								
							
						
					
					
						commit
						f3424c6864
					
				| @ -185,19 +185,19 @@ func AfterCreate(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| // ConvertToCreateValues convert to create values
 | ||||
| func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { | ||||
| func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | ||||
| 	switch value := stmt.Dest.(type) { | ||||
| 	case map[string]interface{}: | ||||
| 		return ConvertMapToValuesForCreate(stmt, value) | ||||
| 		values = ConvertMapToValuesForCreate(stmt, value) | ||||
| 	case []map[string]interface{}: | ||||
| 		return ConvertSliceOfMapToValuesForCreate(stmt, value) | ||||
| 		values = ConvertSliceOfMapToValuesForCreate(stmt, value) | ||||
| 	default: | ||||
| 		var ( | ||||
| 			values                    = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} | ||||
| 			selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) | ||||
| 			curTime                   = stmt.DB.NowFunc() | ||||
| 			isZero                    bool | ||||
| 		) | ||||
| 		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} | ||||
| 
 | ||||
| 		for _, db := range stmt.Schema.DBNames { | ||||
| 			if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { | ||||
| @ -274,7 +274,30 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if stmt.UpdatingColumn { | ||||
| 		if stmt.Schema != nil { | ||||
| 			columns := make([]string, 0, len(stmt.Schema.DBNames)-1) | ||||
| 			for _, name := range stmt.Schema.DBNames { | ||||
| 				if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 					if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { | ||||
| 						columns = append(columns, name) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			onConflict := clause.OnConflict{ | ||||
| 				Columns:   make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), | ||||
| 				DoUpdates: clause.AssignmentColumns(columns), | ||||
| 			} | ||||
| 
 | ||||
| 			for idx, field := range stmt.Schema.PrimaryFields { | ||||
| 				onConflict.Columns[idx] = clause.Column{Name: field.DBName} | ||||
| 			} | ||||
| 			stmt.AddClause(onConflict) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return values | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -22,13 +22,14 @@ func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.Dest = value | ||||
| 
 | ||||
| 	if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | ||||
| 		where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | ||||
| 	reflectValue := reflect.Indirect(reflect.ValueOf(value)) | ||||
| 	switch reflectValue.Kind() { | ||||
| 	case reflect.Slice, reflect.Array: | ||||
| 			tx.AddError(ErrPtrStructSupported) | ||||
| 		tx.Statement.UpdatingColumn = true | ||||
| 		tx.callbacks.Create().Execute(tx) | ||||
| 	case reflect.Struct: | ||||
| 		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | ||||
| 			where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | ||||
| 			for idx, pf := range tx.Statement.Schema.PrimaryFields { | ||||
| 				if pv, isZero := pf.ValueOf(reflectValue); isZero { | ||||
| 					tx.callbacks.Create().Execute(tx) | ||||
| @ -40,12 +41,16 @@ func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| 
 | ||||
| 			tx.Statement.AddClause(where) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 		fallthrough | ||||
| 	default: | ||||
| 		if len(tx.Statement.Selects) == 0 { | ||||
| 			tx.Statement.Selects = append(tx.Statement.Selects, "*") | ||||
| 		} | ||||
| 
 | ||||
| 		tx.callbacks.Update().Execute(tx) | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -90,6 +90,23 @@ func TestUpsertSlice(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpsertWithSave(t *testing.T) { | ||||
| 	langs := []Language{ | ||||
| 		{Code: "upsert-save-1", Name: "Upsert-save-1"}, | ||||
| 		{Code: "upsert-save-2", Name: "Upsert-save-2"}, | ||||
| 	} | ||||
| 	if err := DB.Save(&langs).Error; err != nil { | ||||
| 		t.Errorf("Failed to create, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, lang := range langs { | ||||
| 		var result Language | ||||
| 		if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { | ||||
| 			t.Errorf("Failed to query lang, got error %v", err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFindOrInitialize(t *testing.T) { | ||||
| 	var user1, user2, user3, user4, user5, user6 User | ||||
| 	if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu