Merge branch 'master' into embedded_preload
This commit is contained in:
commit
f2f58750f5
@ -3,6 +3,7 @@ package callbacks
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -302,7 +303,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
for _, column := range values.Columns {
|
for _, column := range values.Columns {
|
||||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
|
||||||
|
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
|
||||||
if field.AutoUpdateTime > 0 {
|
if field.AutoUpdateTime > 0 {
|
||||||
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||||
switch field.AutoUpdateTime {
|
switch field.AutoUpdateTime {
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(db *gorm.DB) {
|
func Query(db *gorm.DB) {
|
||||||
@ -109,13 +111,46 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
specifiedRelationsName := make(map[string]interface{})
|
||||||
for _, join := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if db.Statement.Schema == nil {
|
if db.Statement.Schema != nil {
|
||||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
var isRelations bool // is relations or raw sql
|
||||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
var relations []*schema.Relationship
|
||||||
})
|
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
if ok {
|
||||||
|
isRelations = true
|
||||||
|
relations = append(relations, relation)
|
||||||
|
} else {
|
||||||
|
// handle nested join like "Manager.Company"
|
||||||
|
nestedJoinNames := strings.Split(join.Name, ".")
|
||||||
|
if len(nestedJoinNames) > 1 {
|
||||||
|
isNestedJoin := true
|
||||||
|
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||||
|
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||||
|
for _, relname := range nestedJoinNames {
|
||||||
|
// incomplete match, only treated as raw sql
|
||||||
|
if relation, ok = currentRelations[relname]; ok {
|
||||||
|
gussNestedRelations = append(gussNestedRelations, relation)
|
||||||
|
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||||
|
} else {
|
||||||
|
isNestedJoin = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNestedJoin {
|
||||||
|
isRelations = true
|
||||||
|
relations = gussNestedRelations
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isRelations {
|
||||||
|
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||||
tableAliasName := relation.Name
|
tableAliasName := relation.Name
|
||||||
|
if parentTableName != clause.CurrentTable {
|
||||||
|
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||||
|
}
|
||||||
|
|
||||||
columnStmt := gorm.Statement{
|
columnStmt := gorm.Statement{
|
||||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||||
@ -128,7 +163,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||||
Table: tableAliasName,
|
Table: tableAliasName,
|
||||||
Name: s,
|
Name: s,
|
||||||
Alias: tableAliasName + "__" + s,
|
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -137,13 +172,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
for idx, ref := range relation.References {
|
for idx, ref := range relation.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
exprs[idx] = clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if ref.PrimaryValue == "" {
|
if ref.PrimaryValue == "" {
|
||||||
exprs[idx] = clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -184,11 +219,28 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
return clause.Join{
|
||||||
Type: join.JoinType,
|
Type: joinType,
|
||||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||||
ON: clause.Where{Exprs: exprs},
|
ON: clause.Where{Exprs: exprs},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parentTableName := clause.CurrentTable
|
||||||
|
for _, rel := range relations {
|
||||||
|
// joins table alias like "Manager, Company, Manager__Company"
|
||||||
|
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||||
|
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||||
|
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
||||||
|
specifiedRelationsName[nestedAlias] = nil
|
||||||
|
}
|
||||||
|
parentTableName = rel.Name
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||||
|
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||||
|
@ -245,11 +245,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
updatingSchema := stmt.Schema
|
updatingSchema := stmt.Schema
|
||||||
|
var isDiffSchema bool
|
||||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||||
// different schema
|
// different schema
|
||||||
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
||||||
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
||||||
updatingSchema = updatingStmt.Schema
|
updatingSchema = updatingStmt.Schema
|
||||||
|
isDiffSchema = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -276,7 +278,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
|
|
||||||
if (ok || !isZero) && field.Updatable {
|
if (ok || !isZero) && field.Updatable {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||||
assignValue(field, value)
|
assignField := field
|
||||||
|
if isDiffSchema {
|
||||||
|
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
|
||||||
|
assignField = originField
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assignValue(assignField, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
@ -608,6 +610,15 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
|||||||
return fc(tx)
|
return fc(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
savepointIdx int64
|
||||||
|
savepointNamePool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
||||||
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
||||||
// they are rolled back.
|
// they are rolled back.
|
||||||
@ -617,7 +628,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||||
// nested transaction
|
// nested transaction
|
||||||
if !db.DisableNestedTransaction {
|
if !db.DisableNestedTransaction {
|
||||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
poolName := savepointNamePool.Get()
|
||||||
|
defer savepointNamePool.Put(poolName)
|
||||||
|
err = db.SavePoint(poolName.(string)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -625,7 +638,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||||||
defer func() {
|
defer func() {
|
||||||
// Make sure to rollback when panic, Block error or Commit error
|
// Make sure to rollback when panic, Block error or Commit error
|
||||||
if panicked || err != nil {
|
if panicked || err != nil {
|
||||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
db.RollbackTo(poolName.(string))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
4
gorm.go
4
gorm.go
@ -347,15 +347,17 @@ func (db *DB) Callback() *callbacks {
|
|||||||
|
|
||||||
// AddError add error to db
|
// AddError add error to db
|
||||||
func (db *DB) AddError(err error) error {
|
func (db *DB) AddError(err error) error {
|
||||||
|
if err != nil {
|
||||||
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
|
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
|
||||||
err = errTranslator.Translate(err)
|
err = errTranslator.Translate(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Error == nil {
|
if db.Error == nil {
|
||||||
db.Error = err
|
db.Error = err
|
||||||
} else if err != nil {
|
} else {
|
||||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return db.Error
|
return db.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
60
scan.go
60
scan.go
@ -4,10 +4,10 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// prepareValues prepare values slice
|
// prepareValues prepare values slice
|
||||||
@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
if field != nil {
|
if field != nil {
|
||||||
values[idx] = field.NewValuePool.Get()
|
values[idx] = field.NewValuePool.Get()
|
||||||
@ -65,28 +65,45 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
|
|||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
joinedSchemaMap := make(map[*schema.Field]interface{})
|
joinedNestedSchemaMap := make(map[string]interface{})
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
if field == nil {
|
if field == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
|
||||||
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||||
} else {
|
} else { // joinFields count is larger than 2 when using join
|
||||||
joinSchema := joinFields[idx][0]
|
var isNilPtrValue bool
|
||||||
relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue)
|
var relValue reflect.Value
|
||||||
|
// does not contain raw dbname
|
||||||
|
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
|
||||||
|
// current reflect value
|
||||||
|
currentReflectValue := reflectValue
|
||||||
|
fullRels := make([]string, 0, len(nestedJoinSchemas))
|
||||||
|
for _, joinSchema := range nestedJoinSchemas {
|
||||||
|
fullRels = append(fullRels, joinSchema.Name)
|
||||||
|
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
|
||||||
if relValue.Kind() == reflect.Ptr {
|
if relValue.Kind() == reflect.Ptr {
|
||||||
if _, ok := joinedSchemaMap[joinSchema]; !ok {
|
fullRelsName := utils.JoinNestedRelationNames(fullRels)
|
||||||
|
// same nested structure
|
||||||
|
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
|
||||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||||
continue
|
isNilPtrValue = true
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||||
joinedSchemaMap[joinSchema] = nil
|
joinedNestedSchemaMap[fullRelsName] = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
|
currentReflectValue = relValue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isNilPtrValue { // ignore if value is nil
|
||||||
|
f := joinFields[idx][len(joinFields[idx])-1]
|
||||||
|
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// release data to pool
|
// release data to pool
|
||||||
@ -163,7 +180,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
fields = make([]*schema.Field, len(columns))
|
fields = make([]*schema.Field, len(columns))
|
||||||
joinFields [][2]*schema.Field
|
joinFields [][]*schema.Field
|
||||||
sch = db.Statement.Schema
|
sch = db.Statement.Schema
|
||||||
reflectValue = db.Statement.ReflectValue
|
reflectValue = db.Statement.ReflectValue
|
||||||
)
|
)
|
||||||
@ -217,15 +234,26 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||||||
} else {
|
} else {
|
||||||
matchedFieldCount[column] = 1
|
matchedFieldCount[column] = 1
|
||||||
}
|
}
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
subNameCount := len(names)
|
||||||
|
// nested relation fields
|
||||||
|
relFields := make([]*schema.Field, 0, subNameCount-1)
|
||||||
|
relFields = append(relFields, rel.Field)
|
||||||
|
for _, name := range names[1 : subNameCount-1] {
|
||||||
|
rel = rel.FieldSchema.Relationships.Relations[name]
|
||||||
|
relFields = append(relFields, rel.Field)
|
||||||
|
}
|
||||||
|
// lastest name is raw dbname
|
||||||
|
dbName := names[subNameCount-1]
|
||||||
|
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
|
|
||||||
if len(joinFields) == 0 {
|
if len(joinFields) == 0 {
|
||||||
joinFields = make([][2]*schema.Field, len(columns))
|
joinFields = make([][]*schema.Field, len(columns))
|
||||||
}
|
}
|
||||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
relFields = append(relFields, field)
|
||||||
|
joinFields[idx] = relFields
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -920,6 +920,8 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
|
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||||
|
serializerType := serializerValue.Type()
|
||||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
if s, ok := v.(*serializer); ok {
|
if s, ok := v.(*serializer); ok {
|
||||||
if s.fieldValue != nil {
|
if s.fieldValue != nil {
|
||||||
@ -927,11 +929,12 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
|
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
|
||||||
if sameElemType {
|
if sameElemType {
|
||||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
|
||||||
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
|
||||||
} else if sameType {
|
} else if sameType {
|
||||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
|
||||||
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
|
||||||
}
|
}
|
||||||
|
si := reflect.New(serializerType)
|
||||||
|
si.Elem().Set(serializerValue)
|
||||||
|
s.Serializer = si.Interface().(SerializerInterface)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = oldFieldSetter(ctx, value, v)
|
err = oldFieldSetter(ctx, value, v)
|
||||||
@ -943,11 +946,15 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
|
|
||||||
func (field *Field) setupNewValuePool() {
|
func (field *Field) setupNewValuePool() {
|
||||||
if field.Serializer != nil {
|
if field.Serializer != nil {
|
||||||
|
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||||
|
serializerType := serializerValue.Type()
|
||||||
field.NewValuePool = &sync.Pool{
|
field.NewValuePool = &sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
|
si := reflect.New(serializerType)
|
||||||
|
si.Elem().Set(serializerValue)
|
||||||
return &serializer{
|
return &serializer{
|
||||||
Field: field,
|
Field: field,
|
||||||
Serializer: field.Serializer,
|
Serializer: si.Interface().(SerializerInterface),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -250,8 +250,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
|
if schema.PrioritizedPrimaryField == nil {
|
||||||
|
if len(schema.PrimaryFields) == 1 {
|
||||||
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
||||||
|
} else if len(schema.PrimaryFields) > 1 {
|
||||||
|
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
|
||||||
|
for _, field := range schema.PrimaryFields {
|
||||||
|
if field.AutoIncrement {
|
||||||
|
schema.PrioritizedPrimaryField = field
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range schema.PrimaryFields {
|
for _, field := range schema.PrimaryFields {
|
||||||
|
@ -293,3 +293,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
|
||||||
|
type Product struct {
|
||||||
|
ProductID uint `gorm:"primaryKey;autoIncrement"`
|
||||||
|
LanguageCode uint `gorm:"primaryKey"`
|
||||||
|
Code string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
type ProductNonAutoIncrement struct {
|
||||||
|
ProductID uint `gorm:"primaryKey;autoIncrement:false"`
|
||||||
|
LanguageCode uint `gorm:"primaryKey"`
|
||||||
|
Code string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prioritizedPrimaryField := schema.Field{
|
||||||
|
Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
|
||||||
|
}
|
||||||
|
|
||||||
|
product.Fields = []*schema.Field{product.PrioritizedPrimaryField}
|
||||||
|
|
||||||
|
checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
|
||||||
|
f.Creatable = true
|
||||||
|
f.Updatable = true
|
||||||
|
f.Readable = true
|
||||||
|
})
|
||||||
|
|
||||||
|
productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if productNonAutoIncrement.PrioritizedPrimaryField != nil {
|
||||||
|
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -547,3 +547,68 @@ func TestFirstOrCreateRowsAffected(t *testing.T) {
|
|||||||
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
|
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
|
||||||
|
type CompositeKeyProduct struct {
|
||||||
|
ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key
|
||||||
|
LanguageCode int `gorm:"primaryKey;"` // primary key
|
||||||
|
Code string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil {
|
||||||
|
t.Fatalf("failed to migrate, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prod := &CompositeKeyProduct{
|
||||||
|
LanguageCode: 56,
|
||||||
|
Code: "Code56",
|
||||||
|
Name: "ProductName56",
|
||||||
|
}
|
||||||
|
if err := DB.Create(&prod).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newProd := &CompositeKeyProduct{}
|
||||||
|
if err := DB.First(&newProd).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when query: %v", err)
|
||||||
|
} else {
|
||||||
|
AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
|
||||||
|
type OnConfilctUser struct {
|
||||||
|
ID string
|
||||||
|
Name string `gorm:"default:null"`
|
||||||
|
Email string
|
||||||
|
Mobile string `gorm:"default:'133xxxx'"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := DB.Migrator().DropTable(&OnConfilctUser{})
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
err = DB.AutoMigrate(&OnConfilctUser{})
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
|
||||||
|
u := OnConfilctUser{
|
||||||
|
ID: "on-confilct-user-id",
|
||||||
|
Name: "on-confilct-user-name",
|
||||||
|
Email: "on-confilct-user-email",
|
||||||
|
Mobile: "on-confilct-user-mobile",
|
||||||
|
}
|
||||||
|
err = DB.Create(&u).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
|
||||||
|
u.Name = "on-confilct-user-name-2"
|
||||||
|
u.Email = "on-confilct-user-email-2"
|
||||||
|
u.Mobile = ""
|
||||||
|
err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
|
||||||
|
var u2 OnConfilctUser
|
||||||
|
err = DB.Where("id = ?", u.ID).First(&u2).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
AssertEqual(t, u2.Name, "on-confilct-user-name-2")
|
||||||
|
AssertEqual(t, u2.Email, "on-confilct-user-email-2")
|
||||||
|
AssertEqual(t, u2.Mobile, "133xxxx")
|
||||||
|
}
|
||||||
|
@ -9,11 +9,12 @@ require (
|
|||||||
github.com/lib/pq v1.10.7
|
github.com/lib/pq v1.10.7
|
||||||
github.com/mattn/go-sqlite3 v1.14.16 // indirect
|
github.com/mattn/go-sqlite3 v1.14.16 // indirect
|
||||||
github.com/microsoft/go-mssqldb v0.20.0 // indirect
|
github.com/microsoft/go-mssqldb v0.20.0 // indirect
|
||||||
|
golang.org/x/crypto v0.7.0 // indirect
|
||||||
gorm.io/driver/mysql v1.4.7
|
gorm.io/driver/mysql v1.4.7
|
||||||
gorm.io/driver/postgres v1.4.8
|
gorm.io/driver/postgres v1.5.0
|
||||||
gorm.io/driver/sqlite v1.4.4
|
gorm.io/driver/sqlite v1.4.4
|
||||||
gorm.io/driver/sqlserver v1.4.2
|
gorm.io/driver/sqlserver v1.4.2
|
||||||
gorm.io/gorm v1.24.5
|
gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
@ -325,3 +325,66 @@ func TestJoinArgsWithDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
AssertEqual(t, user4.NamedPet.Name, "")
|
AssertEqual(t, user4.NamedPet.Name, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNestedJoins(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
{
|
||||||
|
Name: "nested-joins-1",
|
||||||
|
Manager: GetUser("nested-joins-manager-1", Config{Company: true, NamedPet: true}),
|
||||||
|
NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "nested-joins-2",
|
||||||
|
Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||||
|
NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var userIDs []uint
|
||||||
|
for _, user := range users {
|
||||||
|
userIDs = append(userIDs, user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users2 []User
|
||||||
|
if err := DB.
|
||||||
|
Joins("Manager").
|
||||||
|
Joins("Manager.Company").
|
||||||
|
Joins("Manager.NamedPet").
|
||||||
|
Joins("NamedPet").
|
||||||
|
Joins("NamedPet.Toy").
|
||||||
|
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||||
|
} else if len(users2) != len(users) {
|
||||||
|
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(users2, func(i, j int) bool {
|
||||||
|
return users2[i].ID > users2[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(users, func(i, j int) bool {
|
||||||
|
return users[i].ID > users[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
for idx, user := range users {
|
||||||
|
// user
|
||||||
|
CheckUser(t, user, users2[idx])
|
||||||
|
if users2[idx].Manager == nil {
|
||||||
|
t.Fatalf("Failed to load Manager")
|
||||||
|
}
|
||||||
|
// manager
|
||||||
|
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||||
|
// user pet
|
||||||
|
if users2[idx].NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||||
|
// manager pet
|
||||||
|
if users2[idx].Manager.NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1542,3 +1542,59 @@ func TestMigrateView(t *testing.T) {
|
|||||||
t.Fatalf("Failed to drop view, got %v", err)
|
t.Fatalf("Failed to drop view, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrateExistingBoolColumnPG(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "postgres" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type ColumnStruct struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
StringBool string
|
||||||
|
SmallintBool int `gorm:"type:smallint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ColumnStruct2 struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
StringBool bool // change existing boolean column from string to boolean
|
||||||
|
SmallintBool bool // change existing boolean column from smallint or other to boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&ColumnStruct{})
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&ColumnStruct{}); err != nil {
|
||||||
|
t.Errorf("Failed to migrate, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
|
||||||
|
t.Fatalf("no error should happened when auto migrate column, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
|
||||||
|
t.Fatalf("no error should returns for ColumnTypes")
|
||||||
|
} else {
|
||||||
|
stmt := &gorm.Statement{DB: DB}
|
||||||
|
stmt.Parse(&ColumnStruct2{})
|
||||||
|
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
switch columnType.Name() {
|
||||||
|
case "id":
|
||||||
|
if v, ok := columnType.PrimaryKey(); !ok || !v {
|
||||||
|
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
case "string_bool":
|
||||||
|
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||||
|
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||||
|
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||||
|
}
|
||||||
|
case "smallint_bool":
|
||||||
|
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||||
|
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||||
|
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -773,3 +773,16 @@ func TestUpdateReturning(t *testing.T) {
|
|||||||
t.Errorf("failed to return updated age column")
|
t.Errorf("failed to return updated age column")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateWithDiffSchema(t *testing.T) {
|
||||||
|
user := GetUser("update-diff-schema-1", Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
type UserTemp struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
err := DB.Model(&user).Updates(&UserTemp{Name: "update-diff-schema-2"}).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
AssertEqual(t, "update-diff-schema-2", user.Name)
|
||||||
|
}
|
||||||
|
@ -13,8 +13,14 @@ import (
|
|||||||
|
|
||||||
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
|
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
|
rv := reflect.Indirect(reflect.ValueOf(r))
|
||||||
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
ev := reflect.Indirect(reflect.ValueOf(e))
|
||||||
|
if rv.IsValid() != ev.IsValid() {
|
||||||
|
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
got := rv.FieldByName(name).Interface()
|
||||||
|
expect := ev.FieldByName(name).Interface()
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
AssertEqual(t, got, expect)
|
AssertEqual(t, got, expect)
|
||||||
})
|
})
|
||||||
|
@ -131,3 +131,20 @@ func ToString(value interface{}) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const nestedRelationSplit = "__"
|
||||||
|
|
||||||
|
// NestedRelationName nested relationships like `Manager__Company`
|
||||||
|
func NestedRelationName(prefix, name string) string {
|
||||||
|
return prefix + nestedRelationSplit + name
|
||||||
|
}
|
||||||
|
|
||||||
|
// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}`
|
||||||
|
func SplitNestedRelationName(name string) []string {
|
||||||
|
return strings.Split(name, nestedRelationSplit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinNestedRelationNames nested relationships like `Manager__Company`
|
||||||
|
func JoinNestedRelationNames(relationNames []string) string {
|
||||||
|
return strings.Join(relationNames, nestedRelationSplit)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user