map insert support return increment id (#6662)
This commit is contained in:
		
							parent
							
								
									c1e911f6ed
								
							
						
					
					
						commit
						3207ad6033
					
				| @ -103,9 +103,22 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 		if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||
| 			db.Statement.Schema.PrioritizedPrimaryField != nil && | ||||
| 			db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 		if db.RowsAffected == 0 { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		var ( | ||||
| 			pkField     *schema.Field | ||||
| 			pkFieldName = "@id" | ||||
| 		) | ||||
| 		if db.Statement.Schema != nil { | ||||
| 			if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 				return | ||||
| 			} | ||||
| 			pkField = db.Statement.Schema.PrioritizedPrimaryField | ||||
| 			pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		insertID, err := result.LastInsertId() | ||||
| 		insertOk := err == nil && insertID > 0 | ||||
| 		if !insertOk { | ||||
| @ -113,6 +126,33 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// append @id column with value for auto-increment primary key
 | ||||
| 		// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
 | ||||
| 		switch values := db.Statement.Dest.(type) { | ||||
| 		case map[string]interface{}: | ||||
| 			values[pkFieldName] = insertID | ||||
| 		case *map[string]interface{}: | ||||
| 			(*values)[pkFieldName] = insertID | ||||
| 		case []map[string]interface{}, *[]map[string]interface{}: | ||||
| 			mapValues, ok := values.([]map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				if v, ok := values.(*[]map[string]interface{}); ok { | ||||
| 					if *v != nil { | ||||
| 						mapValues = *v | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			for _, mapValue := range mapValues { | ||||
| 				if mapValue != nil { | ||||
| 					mapValue[pkFieldName] = insertID | ||||
| 				} | ||||
| 				insertID += schema.DefaultAutoIncrementIncrement | ||||
| 			} | ||||
| 		default: | ||||
| 			if pkField == nil { | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				if config.LastInsertIDReversed { | ||||
| @ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 							break | ||||
| 						} | ||||
| 
 | ||||
| 						_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) | ||||
| 						_, isZero := pkField.ValueOf(db.Statement.Context, rv) | ||||
| 						if isZero { | ||||
| 							db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 							db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID -= pkField.AutoIncrementIncrement | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| @ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 							break | ||||
| 						} | ||||
| 
 | ||||
| 						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { | ||||
| 							db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 						if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { | ||||
| 							db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID += pkField.AutoIncrementIncrement | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			case reflect.Struct: | ||||
| 				_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) | ||||
| 				_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) | ||||
| 				if isZero { | ||||
| 					db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) | ||||
| 					db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -49,6 +49,8 @@ const ( | ||||
| 	Bytes  DataType = "bytes" | ||||
| ) | ||||
| 
 | ||||
| const DefaultAutoIncrementIncrement int64 = 1 | ||||
| 
 | ||||
| // Field is the representation of model schema's field
 | ||||
| type Field struct { | ||||
| 	Name                   string | ||||
| @ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||
| 		NotNull:                utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), | ||||
| 		Unique:                 utils.CheckTruth(tagSetting["UNIQUE"]), | ||||
| 		Comment:                tagSetting["COMMENT"], | ||||
| 		AutoIncrementIncrement: 1, | ||||
| 		AutoIncrementIncrement: DefaultAutoIncrementIncrement, | ||||
| 	} | ||||
| 
 | ||||
| 	for field.IndirectFieldType.Kind() == reflect.Ptr { | ||||
|  | ||||
| @ -2,6 +2,7 @@ package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| @ -580,7 +581,7 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateOnConfilctWithDefalutNull(t *testing.T) { | ||||
| func TestCreateOnConflictWithDefaultNull(t *testing.T) { | ||||
| 	type OnConfilctUser struct { | ||||
| 		ID     string | ||||
| 		Name   string `gorm:"default:null"` | ||||
| @ -615,3 +616,180 @@ func TestCreateOnConfilctWithDefalutNull(t *testing.T) { | ||||
| 	AssertEqual(t, u2.Email, "on-confilct-user-email-2") | ||||
| 	AssertEqual(t, u2.Mobile, "133xxxx") | ||||
| } | ||||
| 
 | ||||
| func TestCreateFromMapWithoutPK(t *testing.T) { | ||||
| 	if !isMysql() { | ||||
| 		t.Skipf("This test case skipped, because of only supportting for mysql") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 1: one record, create from map[string]interface{}
 | ||||
| 	mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1} | ||||
| 	if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := mapValue1["id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	var result1 User | ||||
| 	if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var idVal int64 | ||||
| 	_, ok := mapValue1["id"].(uint) | ||||
| 	if ok { | ||||
| 		t.Skipf("This test case skipped, because the db supports returning") | ||||
| 	} | ||||
| 
 | ||||
| 	idVal, ok = mapValue1["id"].(int64) | ||||
| 	if !ok { | ||||
| 		t.Fatal("ret result missing id") | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(result1.ID) != idVal { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case2: one record, create from *map[string]interface{}
 | ||||
| 	mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1} | ||||
| 	if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := mapValue2["id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	_, ok = mapValue2["id"].(uint) | ||||
| 	if ok { | ||||
| 		t.Skipf("This test case skipped, because the db supports returning") | ||||
| 	} | ||||
| 
 | ||||
| 	idVal, ok = mapValue2["id"].(int64) | ||||
| 	if !ok { | ||||
| 		t.Fatal("ret result missing id") | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(result2.ID) != idVal { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 3: records
 | ||||
| 	values := []map[string]interface{}{ | ||||
| 		{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1}, | ||||
| 	} | ||||
| 
 | ||||
| 	beforeLen := len(values) | ||||
| 	if err := DB.Model(&User{}).Create(&values).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// mariadb with returning, values will be appended with id map
 | ||||
| 	if len(values) == beforeLen*2 { | ||||
| 		t.Skipf("This test case skipped, because the db supports returning") | ||||
| 	} | ||||
| 
 | ||||
| 	for i := range values { | ||||
| 		v, ok := values[i]["id"] | ||||
| 		if !ok { | ||||
| 			t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||
| 		} | ||||
| 
 | ||||
| 		var result User | ||||
| 		if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 { | ||||
| 			t.Fatalf("failed to create from map, got error %v", err) | ||||
| 		} | ||||
| 		if int64(result.ID) != v.(int64) { | ||||
| 			t.Fatal("failed to create data from map with table, @id != id") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 	if !isMysql() { | ||||
| 		t.Skipf("This test case skipped, because of only supportting for mysql") | ||||
| 	} | ||||
| 	tableDB := DB.Table("`users`") | ||||
| 
 | ||||
| 	// case 1: create from map[string]interface{}
 | ||||
| 	record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} | ||||
| 	if err := tableDB.Create(record).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map with table, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := record["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	var res map[string]interface{} | ||||
| 	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res["id"].(uint64)) != record["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 2: create from *map[string]interface{}
 | ||||
| 	record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18} | ||||
| 	tableDB2 := DB.Table("users") | ||||
| 	if err := tableDB2.Create(&record1).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 	if _, ok := record1["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	var res1 map[string]interface{} | ||||
| 	if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res1["id"].(uint64)) != record1["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 3: create from []map[string]interface{}
 | ||||
| 	records := []map[string]interface{}{ | ||||
| 		{"name": "create_from_map_with_table_2", "age": 19}, | ||||
| 		{"name": "create_from_map_with_table_3", "age": 20}, | ||||
| 	} | ||||
| 
 | ||||
| 	tableDB = DB.Table("users") | ||||
| 	if err := tableDB.Create(&records).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from slice of map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := records[0]["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := records[1]["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	var res2 map[string]interface{} | ||||
| 	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) { | ||||
| 		t.Fatalf("failed to query data after create from slice of map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var res3 map[string]interface{} | ||||
| 	if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) { | ||||
| 		t.Fatalf("failed to query data after create from slice of map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res2["id"].(uint64)) != records[0]["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res3["id"].(uint64)) != records[1]["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 FangSqing
						FangSqing