diff --git a/association.go b/association.go index 62c25b71..09e79ca6 100644 --- a/association.go +++ b/association.go @@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error { primaryFields []*schema.Field foreignKeys []string updateMap = map[string]interface{}{} - relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) - if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } @@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error { } } - if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } @@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrPrimaryKeyRequired } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } @@ -186,11 +186,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.BelongsTo: tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -198,11 +198,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.HasOne, schema.HasMany: tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -228,11 +228,11 @@ func (association *Association) Delete(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -241,11 +241,11 @@ func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { // clean up deleted values's foreign key - relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { - if _, zero := rel.Field.ValueOf(data); !zero { - fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { @@ -253,7 +253,7 @@ func (association *Association) Delete(values ...interface{}) error { validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { @@ -261,23 +261,23 @@ func (association *Association) Delete(values ...interface{}) error { } } - association.Error = rel.Field.Set(data, validFieldValues.Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { - if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { break } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { - association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { - association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -329,14 +329,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { - association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: - association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) @@ -344,7 +344,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) if clear { fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() } @@ -373,7 +373,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) } } } @@ -421,7 +421,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { association.Error = err break } @@ -429,7 +429,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { association.Error = err break } @@ -453,12 +453,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Struct: // clear old data if clear && len(values) == 0 { - association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -475,7 +475,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } for _, assignBack := range assignBacks { - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) if assignBack.Index > 0 { reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) } else { @@ -486,7 +486,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ func (association *Association) buildCondition() *DB { var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) diff --git a/callbacks/associations.go b/callbacks/associations.go index 75bd6c6a..d6fd21de 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { setupReferences := func(obj reflect.Value, elem reflect.Value) { for _, ref := range rel.References { if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elem) - db.AddError(ref.ForeignKey.Set(obj, pv)) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { dest[ref.ForeignKey.DBName] = pv @@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { break } - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value objs = append(objs, obj) if isPtr { elems = reflect.Append(elems, rv) @@ -76,8 +76,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value if rv.Kind() != reflect.Ptr { rv = rv.Addr() } @@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) if rv.Kind() != reflect.Ptr { rv = rv.Addr() } for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) } } @@ -149,8 +149,8 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) if f.Kind() != reflect.Ptr { f = f.Addr() } @@ -158,10 +158,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(f, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + ref.ForeignKey.Set(db.Statement.Context, f, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(f, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) for _, ref := range rel.References { if ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(v) - ref.ForeignKey.Set(elem, pv) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) + ref.ForeignKey.Set(db.Statement.Context, elem, pv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue) } } relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) for _, pf := range rel.FieldSchema.PrimaryFields { - if pfv, ok := pf.ValueOf(elem); !ok { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { relPrimaryValues = append(relPrimaryValues, pfv) } } @@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joinValue := reflect.New(rel.JoinTable.ModelType) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue) } else { - fv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) } } joins = reflect.Append(joins, joinValue) } appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) diff --git a/callbacks/create.go b/callbacks/create.go index 29113128..b0964e2b 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -117,9 +117,9 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -130,16 +130,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID) } } } @@ -219,23 +219,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(rv, field.DefaultValueInterface) + field.Set(stmt.Context, rv, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(rv); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) } @@ -259,23 +259,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.ReflectValue, field.DefaultValueInterface) + field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], rvOfvalue) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 7f1e09ce..1fb5261c 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { switch rel.Type { case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false @@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { } } - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) queryConds = append(queryConds, clause.IN{Column: column, Values: values}) @@ -123,7 +123,7 @@ func Delete(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Delete{}) if db.Statement.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -131,7 +131,7 @@ func Delete(config *Config) func(db *gorm.DB) { } if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { diff --git a/callbacks/preload.go b/callbacks/preload.go index 41405a22..2363a8ca 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -48,7 +48,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { return } @@ -63,11 +63,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(joinIndexValue) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) + joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -76,7 +76,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -92,7 +92,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { return } @@ -125,17 +125,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } } } @@ -143,7 +143,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(elem) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] @@ -154,7 +154,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) + reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -162,12 +162,12 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(data, elem.Interface()) + rel.Field.Set(db.Statement.Context, data, elem.Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 49086354..03798859 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) { if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { - if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) } } diff --git a/callbacks/update.go b/callbacks/update.go index 511e994e..4f07ca30 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]) } } } @@ -137,13 +137,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { if stmt.ReflectValue.CanAddr() { - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } default: @@ -165,7 +165,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) exprs[idx] = clause.Eq{Column: field.DBName, Value: value} notZero = notZero || !isZero } @@ -178,7 +178,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } @@ -258,7 +258,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field := updatingSchema.LookUpField(dbName); field != nil { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) + value, isZero := field.ValueOf(stmt.Context, updatingValue) if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() @@ -278,7 +278,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } diff --git a/finisher_api.go b/finisher_api.go index 3a179977..d2a8b981 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -83,7 +83,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { - if _, isZero := pf.ValueOf(reflectValue); isZero { + if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { return tx.callbacks.Create().Execute(tx) } } @@ -199,7 +199,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -216,11 +216,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { switch column := eq.Column.(type) { case string: if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } case clause.Column: if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { @@ -238,9 +238,9 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { case reflect.Struct: for _, f := range s.Fields { if f.Readable { - if v, isZero := f.ValueOf(reflectValue); !isZero { + if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero { if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v)) } } } diff --git a/scan.go b/scan.go index b03b79b4..64ea8dbd 100644 --- a/scan.go +++ b/scan.go @@ -77,11 +77,11 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re if sch != nil { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(reflectValue, values[idx]) + field.Set(db.Statement.Context, reflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(reflectValue) + relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) value := reflect.ValueOf(values[idx]).Elem() if relValue.Kind() == reflect.Ptr && relValue.IsNil() { @@ -91,7 +91,7 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re relValue.Set(reflect.New(relValue.Type().Elem())) } - field.Set(relValue, values[idx]) + field.Set(db.Statement.Context, relValue, values[idx]) } } } @@ -244,7 +244,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflectValue.Index(int(db.RowsAffected)) if onConflictDonothing { for _, field := range fields { - if _, ok := field.ValueOf(elem); !ok { + if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { db.RowsAffected++ goto BEGIN } diff --git a/schema/field.go b/schema/field.go index 485bbdf3..f060bc46 100644 --- a/schema/field.go +++ b/schema/field.go @@ -1,6 +1,7 @@ package schema import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -68,9 +69,9 @@ type Field struct { Schema *Schema EmbeddedSchema *Schema OwnerSchema *Schema - ReflectValueOf func(reflect.Value) reflect.Value - ValueOf func(reflect.Value) (value interface{}, zero bool) - Set func(reflect.Value, interface{}) error + ReflectValueOf func(context.Context, reflect.Value) reflect.Value + ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) + Set func(context.Context, reflect.Value, interface{}) error IgnoreMigration bool } @@ -408,22 +409,34 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { return field } +type GormFieldValuer interface { + GormFieldValue(context.Context, *Field) (interface{}, bool) +} + // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { // ValueOf switch { case len(field.StructField.Index) == 1: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue.Interface(), fieldValue.IsZero() + fv, zero := fieldValue.Interface(), fieldValue.IsZero() + if vr, ok := fv.(GormFieldValuer); ok { + fv, zero = vr.GormFieldValue(ctx, field) + } + return fv, zero } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - return fieldValue.Interface(), fieldValue.IsZero() + fv, zero := fieldValue.Interface(), fieldValue.IsZero() + if vr, ok := fv.(GormFieldValuer); ok { + fv, zero = vr.GormFieldValue(ctx, field) + } + return fv, zero } default: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { v := reflect.Indirect(value) for _, idx := range field.StructField.Index { @@ -443,22 +456,26 @@ func (field *Field) setupValuerAndSetter() { } } } - return v.Interface(), v.IsZero() + fv, zero := v.Interface(), v.IsZero() + if vr, ok := fv.(GormFieldValuer); ok { + fv, zero = vr.GormFieldValue(ctx, field) + } + return fv, zero } } // ReflectValueOf switch { case len(field.StructField.Index) == 1: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { return reflect.Indirect(value).Field(field.StructField.Index[0]) } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) } default: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { v := reflect.Indirect(value) for idx, fieldIdx := range field.StructField.Index { if fieldIdx >= 0 { @@ -483,22 +500,22 @@ func (field *Field) setupValuerAndSetter() { } } - fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { if v == nil { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) // Optimal value type acquisition for v reflectValType := reflectV.Type() if reflectValType.AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) return } else if reflectValType.ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType)) return } else if field.FieldType.Kind() == reflect.Ptr { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) fieldType := field.FieldType.Elem() if reflectValType.AssignableTo(fieldType) { @@ -521,13 +538,13 @@ func (field *Field) setupValuerAndSetter() { if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - err = setter(value, reflectV.Elem().Interface()) + err = setter(ctx, value, reflectV.Elem().Interface()) } } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = setter(value, v) + err = setter(ctx, value, v) } } else { return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) @@ -540,191 +557,191 @@ func (field *Field) setupValuerAndSetter() { // Set switch field.FieldType.Kind() { case reflect.Bool: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case bool: - field.ReflectValueOf(value).SetBool(data) + field.ReflectValueOf(ctx, value).SetBool(data) case *bool: if data != nil { - field.ReflectValueOf(value).SetBool(*data) + field.ReflectValueOf(ctx, value).SetBool(*data) } else { - field.ReflectValueOf(value).SetBool(false) + field.ReflectValueOf(ctx, value).SetBool(false) } case int64: if data > 0 { - field.ReflectValueOf(value).SetBool(true) + field.ReflectValueOf(ctx, value).SetBool(true) } else { - field.ReflectValueOf(value).SetBool(false) + field.ReflectValueOf(ctx, value).SetBool(false) } case string: b, _ := strconv.ParseBool(data) - field.ReflectValueOf(value).SetBool(b) + field.ReflectValueOf(ctx, value).SetBool(b) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case int64: - field.ReflectValueOf(value).SetInt(data) + field.ReflectValueOf(ctx, value).SetInt(data) case int: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseInt(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetInt(i) + field.ReflectValueOf(ctx, value).SetInt(i) } else { return err } case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } case *time.Time: if data != nil { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } } else { - field.ReflectValueOf(value).SetInt(0) + field.ReflectValueOf(ctx, value).SetInt(0) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case uint64: - field.ReflectValueOf(value).SetUint(data) + field.ReflectValueOf(ctx, value).SetUint(data) case uint: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) } else { - field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetUint(i) + field.ReflectValueOf(ctx, value).SetUint(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Float32, reflect.Float64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case float64: - field.ReflectValueOf(value).SetFloat(data) + field.ReflectValueOf(ctx, value).SetFloat(data) case float32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseFloat(data, 64); err == nil { - field.ReflectValueOf(value).SetFloat(i) + field.ReflectValueOf(ctx, value).SetFloat(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.String: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case string: - field.ReflectValueOf(value).SetString(data) + field.ReflectValueOf(ctx, value).SetString(data) case []byte: - field.ReflectValueOf(value).SetString(string(data)) + field.ReflectValueOf(ctx, value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(utils.ToString(data)) + field.ReflectValueOf(ctx, value).SetString(utils.ToString(data)) case float64, float32: - field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } @@ -732,41 +749,41 @@ func (field *Field) setupValuerAndSetter() { fieldValue := reflect.New(field.FieldType) switch fieldValue.Elem().Interface().(type) { case time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case *time.Time: if data != nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem()) } else { - field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{})) } case string: if t, err := now.Parse(data); err == nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(t)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case *time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { if v == "" { return nil @@ -778,27 +795,27 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } default: if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } @@ -813,30 +830,30 @@ func (field *Field) setupValuerAndSetter() { } } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() } - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else { - field.Set = func(value reflect.Value, v interface{}) (err error) { - return fallbackSetter(value, v, field.Set) + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + return fallbackSetter(ctx, value, v, field.Set) } } } diff --git a/schema/relationship.go b/schema/relationship.go index c5d3dcad..eae8ab0b 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -1,6 +1,7 @@ package schema import ( + "context" "fmt" "reflect" "strings" @@ -576,7 +577,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { return &constraint } -func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { +func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} @@ -616,7 +617,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] } } - _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) diff --git a/schema/utils.go b/schema/utils.go index e005cc74..2720c530 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -1,6 +1,7 @@ package schema import ( + "context" "reflect" "regexp" "strings" @@ -59,13 +60,13 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct } // GetRelationsValues get relations's values from a reflect value -func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { +func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { - if _, isZero := rel.Field.ValueOf(value); !isZero { - result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) switch result.Kind() { case reflect.Struct: reflectResults = reflect.Append(reflectResults, result.Addr()) @@ -97,7 +98,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle } // GetIdentityFieldValuesMap get identity map from fields -func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( results = [][]interface{}{} dataResults = map[string][]reflect.Value{} @@ -110,7 +111,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { - results[0][idx], zero = field.ValueOf(reflectValue) + results[0][idx], zero = field.ValueOf(ctx, reflectValue) notZero = notZero || !zero } @@ -135,7 +136,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { - fieldValues[idx], zero = field.ValueOf(elem) + fieldValues[idx], zero = field.ValueOf(ctx, elem) notZero = notZero || !zero } @@ -155,12 +156,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map } // GetIdentityFieldValuesMapFromValues get identity map from fields -func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { resultsMap := map[string][]reflect.Value{} results := [][]interface{}{} for _, v := range values { - rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) for k, v := range rm { resultsMap[k] = append(resultsMap[k], v...) } diff --git a/soft_delete.go b/soft_delete.go index 4582161d..ba6d2118 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -135,7 +135,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -143,7 +143,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { diff --git a/statement.go b/statement.go index 23212642..cb471776 100644 --- a/statement.go +++ b/statement.go @@ -389,7 +389,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -403,7 +403,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -562,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . switch destValue.Kind() { case reflect.Struct: - field.Set(destValue, value) + field.Set(stmt.Context, destValue, value) default: stmt.AddError(ErrInvalidData) } @@ -572,10 +572,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } else { - field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { @@ -583,7 +583,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . return } - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } else { stmt.AddError(ErrInvalidField) @@ -603,7 +603,7 @@ func (stmt *Statement) Changed(fields ...string) bool { selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, _ := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { @@ -617,7 +617,7 @@ func (stmt *Statement) Changed(fields ...string) bool { destValue = destValue.Elem() } - changedValue, zero := field.ValueOf(destValue) + changedValue, zero := field.ValueOf(stmt.Context, destValue) return !zero && !utils.AssertEqual(changedValue, fieldValue) } }