Merge branch 'master' into fix_scan_array

This commit is contained in:
Jinzhu 2022-09-22 15:46:29 +08:00 committed by GitHub
commit 3d57f8ecf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 85 additions and 39 deletions

22
scan.go
View File

@ -243,14 +243,20 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var elem reflect.Value
recyclableStruct := reflect.New(reflectValueType)
isArrayKind := reflectValue.Kind() == reflect.Array
var (
elem reflect.Value
recyclableStruct = reflect.New(reflectValueType)
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 {
update = false
if !isArrayKind {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}
@ -285,12 +291,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if !isPtr {
elem = elem.Elem()
}
if !isArrayKind {
reflectValue = reflect.Append(reflectValue, elem)
} else {
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else {
reflectValue = reflect.Append(reflectValue, elem)
}
}
}
@ -314,4 +320,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
db.AddError(ErrRecordNotFound)
}
}
}

View File

@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error
joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]bool{} // fix self join many2many
ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
)
@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
joinFieldName = strings.Title(joinForeignKeys[idx])
}
ownFieldsMap[joinFieldName] = true
ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
for idx, relField := range refForeignFields {
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name {
@ -254,14 +252,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
}
joinTableFields = append(joinTableFields, reflect.StructField{
@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
f.Size = fieldsMap[f.Name].Size
}
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
if ownPrimaryField {
if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
PrimaryKey: of,
ForeignKey: f,
})
} else {
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
}
joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
}
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPrimaryField,
})
}
}
}

View File

@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema")
t.Errorf("Failed to parse schema, got error %v", err)
} else {
for _, rel := range relations {
checkSchemaRelation(t, s, rel)
@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
})
}
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct {
gorm.Model

View File

@ -3,17 +3,18 @@ module gorm.io/gorm/tests
go 1.16
require (
github.com/denisenkom/go-mssqldb v0.12.2 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.6
github.com/mattn/go-sqlite3 v1.14.14 // indirect
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
gorm.io/driver/mysql v1.3.5
gorm.io/driver/postgres v1.3.8
github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.15 // indirect
golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect
gorm.io/driver/mysql v1.3.6
gorm.io/driver/postgres v1.3.10
gorm.io/driver/sqlite v1.3.6
gorm.io/driver/sqlserver v1.3.2
gorm.io/gorm v1.23.8
gorm.io/gorm v1.23.9
)
replace gorm.io/gorm => ../