map insert support return increment id (#6662)
This commit is contained in:
		
							parent
							
								
									c1e911f6ed
								
							
						
					
					
						commit
						3207ad6033
					
				| @ -103,13 +103,53 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		db.RowsAffected, _ = result.RowsAffected() | 		db.RowsAffected, _ = result.RowsAffected() | ||||||
| 		if db.RowsAffected != 0 && db.Statement.Schema != nil && | 		if db.RowsAffected == 0 { | ||||||
| 			db.Statement.Schema.PrioritizedPrimaryField != nil && | 			return | ||||||
| 			db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | 		} | ||||||
| 			insertID, err := result.LastInsertId() | 
 | ||||||
| 			insertOk := err == nil && insertID > 0 | 		var ( | ||||||
| 			if !insertOk { | 			pkField     *schema.Field | ||||||
| 				db.AddError(err) | 			pkFieldName = "@id" | ||||||
|  | 		) | ||||||
|  | 		if db.Statement.Schema != nil { | ||||||
|  | 			if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			pkField = db.Statement.Schema.PrioritizedPrimaryField | ||||||
|  | 			pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		insertID, err := result.LastInsertId() | ||||||
|  | 		insertOk := err == nil && insertID > 0 | ||||||
|  | 		if !insertOk { | ||||||
|  | 			db.AddError(err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// append @id column with value for auto-increment primary key
 | ||||||
|  | 		// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
 | ||||||
|  | 		switch values := db.Statement.Dest.(type) { | ||||||
|  | 		case map[string]interface{}: | ||||||
|  | 			values[pkFieldName] = insertID | ||||||
|  | 		case *map[string]interface{}: | ||||||
|  | 			(*values)[pkFieldName] = insertID | ||||||
|  | 		case []map[string]interface{}, *[]map[string]interface{}: | ||||||
|  | 			mapValues, ok := values.([]map[string]interface{}) | ||||||
|  | 			if !ok { | ||||||
|  | 				if v, ok := values.(*[]map[string]interface{}); ok { | ||||||
|  | 					if *v != nil { | ||||||
|  | 						mapValues = *v | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			for _, mapValue := range mapValues { | ||||||
|  | 				if mapValue != nil { | ||||||
|  | 					mapValue[pkFieldName] = insertID | ||||||
|  | 				} | ||||||
|  | 				insertID += schema.DefaultAutoIncrementIncrement | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			if pkField == nil { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| @ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 							break | 							break | ||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) | 						_, isZero := pkField.ValueOf(db.Statement.Context, rv) | ||||||
| 						if isZero { | 						if isZero { | ||||||
| 							db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) | 							db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) | ||||||
| 							insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | 							insertID -= pkField.AutoIncrementIncrement | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 				} else { | 				} else { | ||||||
| @ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 							break | 							break | ||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { | 						if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { | ||||||
| 							db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) | 							db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) | ||||||
| 							insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | 							insertID += pkField.AutoIncrementIncrement | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
| 				_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) | 				_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) | ||||||
| 				if isZero { | 				if isZero { | ||||||
| 					db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) | 					db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -49,6 +49,8 @@ const ( | |||||||
| 	Bytes  DataType = "bytes" | 	Bytes  DataType = "bytes" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | const DefaultAutoIncrementIncrement int64 = 1 | ||||||
|  | 
 | ||||||
| // Field is the representation of model schema's field
 | // Field is the representation of model schema's field
 | ||||||
| type Field struct { | type Field struct { | ||||||
| 	Name                   string | 	Name                   string | ||||||
| @ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 		NotNull:                utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), | 		NotNull:                utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), | ||||||
| 		Unique:                 utils.CheckTruth(tagSetting["UNIQUE"]), | 		Unique:                 utils.CheckTruth(tagSetting["UNIQUE"]), | ||||||
| 		Comment:                tagSetting["COMMENT"], | 		Comment:                tagSetting["COMMENT"], | ||||||
| 		AutoIncrementIncrement: 1, | 		AutoIncrementIncrement: DefaultAutoIncrementIncrement, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for field.IndirectFieldType.Kind() == reflect.Ptr { | 	for field.IndirectFieldType.Kind() == reflect.Ptr { | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ package tests_test | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| @ -580,7 +581,7 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestCreateOnConfilctWithDefalutNull(t *testing.T) { | func TestCreateOnConflictWithDefaultNull(t *testing.T) { | ||||||
| 	type OnConfilctUser struct { | 	type OnConfilctUser struct { | ||||||
| 		ID     string | 		ID     string | ||||||
| 		Name   string `gorm:"default:null"` | 		Name   string `gorm:"default:null"` | ||||||
| @ -615,3 +616,180 @@ func TestCreateOnConfilctWithDefalutNull(t *testing.T) { | |||||||
| 	AssertEqual(t, u2.Email, "on-confilct-user-email-2") | 	AssertEqual(t, u2.Email, "on-confilct-user-email-2") | ||||||
| 	AssertEqual(t, u2.Mobile, "133xxxx") | 	AssertEqual(t, u2.Mobile, "133xxxx") | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestCreateFromMapWithoutPK(t *testing.T) { | ||||||
|  | 	if !isMysql() { | ||||||
|  | 		t.Skipf("This test case skipped, because of only supportting for mysql") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// case 1: one record, create from map[string]interface{}
 | ||||||
|  | 	mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1} | ||||||
|  | 	if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, ok := mapValue1["id"]; !ok { | ||||||
|  | 		t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var result1 User | ||||||
|  | 	if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 { | ||||||
|  | 		t.Fatalf("failed to create from map, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var idVal int64 | ||||||
|  | 	_, ok := mapValue1["id"].(uint) | ||||||
|  | 	if ok { | ||||||
|  | 		t.Skipf("This test case skipped, because the db supports returning") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	idVal, ok = mapValue1["id"].(int64) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal("ret result missing id") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if int64(result1.ID) != idVal { | ||||||
|  | 		t.Fatal("failed to create data from map with table, @id != id") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// case2: one record, create from *map[string]interface{}
 | ||||||
|  | 	mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1} | ||||||
|  | 	if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, ok := mapValue2["id"]; !ok { | ||||||
|  | 		t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var result2 User | ||||||
|  | 	if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 { | ||||||
|  | 		t.Fatalf("failed to create from map, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, ok = mapValue2["id"].(uint) | ||||||
|  | 	if ok { | ||||||
|  | 		t.Skipf("This test case skipped, because the db supports returning") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	idVal, ok = mapValue2["id"].(int64) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal("ret result missing id") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if int64(result2.ID) != idVal { | ||||||
|  | 		t.Fatal("failed to create data from map with table, @id != id") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// case 3: records
 | ||||||
|  | 	values := []map[string]interface{}{ | ||||||
|  | 		{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	beforeLen := len(values) | ||||||
|  | 	if err := DB.Model(&User{}).Create(&values).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// mariadb with returning, values will be appended with id map
 | ||||||
|  | 	if len(values) == beforeLen*2 { | ||||||
|  | 		t.Skipf("This test case skipped, because the db supports returning") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for i := range values { | ||||||
|  | 		v, ok := values[i]["id"] | ||||||
|  | 		if !ok { | ||||||
|  | 			t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var result User | ||||||
|  | 		if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 { | ||||||
|  | 			t.Fatalf("failed to create from map, got error %v", err) | ||||||
|  | 		} | ||||||
|  | 		if int64(result.ID) != v.(int64) { | ||||||
|  | 			t.Fatal("failed to create data from map with table, @id != id") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCreateFromMapWithTable(t *testing.T) { | ||||||
|  | 	if !isMysql() { | ||||||
|  | 		t.Skipf("This test case skipped, because of only supportting for mysql") | ||||||
|  | 	} | ||||||
|  | 	tableDB := DB.Table("`users`") | ||||||
|  | 
 | ||||||
|  | 	// case 1: create from map[string]interface{}
 | ||||||
|  | 	record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} | ||||||
|  | 	if err := tableDB.Create(record).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create data from map with table, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, ok := record["@id"]; !ok { | ||||||
|  | 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var res map[string]interface{} | ||||||
|  | 	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) { | ||||||
|  | 		t.Fatalf("failed to create from map, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if int64(res["id"].(uint64)) != record["@id"] { | ||||||
|  | 		t.Fatal("failed to create data from map with table, @id != id") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// case 2: create from *map[string]interface{}
 | ||||||
|  | 	record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18} | ||||||
|  | 	tableDB2 := DB.Table("users") | ||||||
|  | 	if err := tableDB2.Create(&record1).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if _, ok := record1["@id"]; !ok { | ||||||
|  | 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var res1 map[string]interface{} | ||||||
|  | 	if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) { | ||||||
|  | 		t.Fatalf("failed to create from map, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if int64(res1["id"].(uint64)) != record1["@id"] { | ||||||
|  | 		t.Fatal("failed to create data from map with table, @id != id") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// case 3: create from []map[string]interface{}
 | ||||||
|  | 	records := []map[string]interface{}{ | ||||||
|  | 		{"name": "create_from_map_with_table_2", "age": 19}, | ||||||
|  | 		{"name": "create_from_map_with_table_3", "age": 20}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tableDB = DB.Table("users") | ||||||
|  | 	if err := tableDB.Create(&records).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create data from slice of map, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, ok := records[0]["@id"]; !ok { | ||||||
|  | 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, ok := records[1]["@id"]; !ok { | ||||||
|  | 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var res2 map[string]interface{} | ||||||
|  | 	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) { | ||||||
|  | 		t.Fatalf("failed to query data after create from slice of map, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var res3 map[string]interface{} | ||||||
|  | 	if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) { | ||||||
|  | 		t.Fatalf("failed to query data after create from slice of map, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if int64(res2["id"].(uint64)) != records[0]["@id"] { | ||||||
|  | 		t.Fatal("failed to create data from map with table, @id != id") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if int64(res3["id"].(uint64)) != records[1]["@id"] { | ||||||
|  | 		t.Fatal("failed to create data from map with table, @id != id") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 FangSqing
						FangSqing