Support save slice of data
This commit is contained in:
		
							parent
							
								
									22ff8377df
								
							
						
					
					
						commit
						f3424c6864
					
				| @ -185,19 +185,19 @@ func AfterCreate(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ConvertToCreateValues convert to create values
 | // ConvertToCreateValues convert to create values
 | ||||||
| func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { | func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | ||||||
| 	switch value := stmt.Dest.(type) { | 	switch value := stmt.Dest.(type) { | ||||||
| 	case map[string]interface{}: | 	case map[string]interface{}: | ||||||
| 		return ConvertMapToValuesForCreate(stmt, value) | 		values = ConvertMapToValuesForCreate(stmt, value) | ||||||
| 	case []map[string]interface{}: | 	case []map[string]interface{}: | ||||||
| 		return ConvertSliceOfMapToValuesForCreate(stmt, value) | 		values = ConvertSliceOfMapToValuesForCreate(stmt, value) | ||||||
| 	default: | 	default: | ||||||
| 		var ( | 		var ( | ||||||
| 			values                    = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} |  | ||||||
| 			selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) | 			selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) | ||||||
| 			curTime                   = stmt.DB.NowFunc() | 			curTime                   = stmt.DB.NowFunc() | ||||||
| 			isZero                    bool | 			isZero                    bool | ||||||
| 		) | 		) | ||||||
|  | 		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} | ||||||
| 
 | 
 | ||||||
| 		for _, db := range stmt.Schema.DBNames { | 		for _, db := range stmt.Schema.DBNames { | ||||||
| 			if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { | 			if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { | ||||||
| @ -274,7 +274,30 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { | |||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 |  | ||||||
| 		return values |  | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	if stmt.UpdatingColumn { | ||||||
|  | 		if stmt.Schema != nil { | ||||||
|  | 			columns := make([]string, 0, len(stmt.Schema.DBNames)-1) | ||||||
|  | 			for _, name := range stmt.Schema.DBNames { | ||||||
|  | 				if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
|  | 					if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { | ||||||
|  | 						columns = append(columns, name) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			onConflict := clause.OnConflict{ | ||||||
|  | 				Columns:   make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), | ||||||
|  | 				DoUpdates: clause.AssignmentColumns(columns), | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for idx, field := range stmt.Schema.PrimaryFields { | ||||||
|  | 				onConflict.Columns[idx] = clause.Column{Name: field.DBName} | ||||||
|  | 			} | ||||||
|  | 			stmt.AddClause(onConflict) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return values | ||||||
| } | } | ||||||
|  | |||||||
| @ -22,13 +22,14 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = value | 	tx.Statement.Dest = value | ||||||
| 
 | 
 | ||||||
| 	if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | 	reflectValue := reflect.Indirect(reflect.ValueOf(value)) | ||||||
| 		where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | 	switch reflectValue.Kind() { | ||||||
| 		reflectValue := reflect.Indirect(reflect.ValueOf(value)) | 	case reflect.Slice, reflect.Array: | ||||||
| 		switch reflectValue.Kind() { | 		tx.Statement.UpdatingColumn = true | ||||||
| 		case reflect.Slice, reflect.Array: | 		tx.callbacks.Create().Execute(tx) | ||||||
| 			tx.AddError(ErrPtrStructSupported) | 	case reflect.Struct: | ||||||
| 		case reflect.Struct: | 		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | ||||||
|  | 			where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | ||||||
| 			for idx, pf := range tx.Statement.Schema.PrimaryFields { | 			for idx, pf := range tx.Statement.Schema.PrimaryFields { | ||||||
| 				if pv, isZero := pf.ValueOf(reflectValue); isZero { | 				if pv, isZero := pf.ValueOf(reflectValue); isZero { | ||||||
| 					tx.callbacks.Create().Execute(tx) | 					tx.callbacks.Create().Execute(tx) | ||||||
| @ -40,12 +41,16 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| 
 | 
 | ||||||
| 			tx.Statement.AddClause(where) | 			tx.Statement.AddClause(where) | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		fallthrough | ||||||
|  | 	default: | ||||||
|  | 		if len(tx.Statement.Selects) == 0 { | ||||||
|  | 			tx.Statement.Selects = append(tx.Statement.Selects, "*") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		tx.callbacks.Update().Execute(tx) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(tx.Statement.Selects) == 0 { |  | ||||||
| 		tx.Statement.Selects = append(tx.Statement.Selects, "*") |  | ||||||
| 	} |  | ||||||
| 	tx.callbacks.Update().Execute(tx) |  | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -90,6 +90,23 @@ func TestUpsertSlice(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestUpsertWithSave(t *testing.T) { | ||||||
|  | 	langs := []Language{ | ||||||
|  | 		{Code: "upsert-save-1", Name: "Upsert-save-1"}, | ||||||
|  | 		{Code: "upsert-save-2", Name: "Upsert-save-2"}, | ||||||
|  | 	} | ||||||
|  | 	if err := DB.Save(&langs).Error; err != nil { | ||||||
|  | 		t.Errorf("Failed to create, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, lang := range langs { | ||||||
|  | 		var result Language | ||||||
|  | 		if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { | ||||||
|  | 			t.Errorf("Failed to query lang, got error %v", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestFindOrInitialize(t *testing.T) { | func TestFindOrInitialize(t *testing.T) { | ||||||
| 	var user1, user2, user3, user4, user5, user6 User | 	var user1, user2, user3, user4, user5, user6 User | ||||||
| 	if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { | 	if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu