Update context

This commit is contained in:
Jinzhu 2022-01-20 21:13:52 +08:00
parent e5894ca449
commit 0ca42129fa
14 changed files with 258 additions and 251 deletions

View File

@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error {
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
} }
case reflect.Struct: case reflect.Struct:
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
} }
for _, ref := range rel.References { for _, ref := range rel.References {
@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error {
primaryFields []*schema.Field primaryFields []*schema.Field
foreignKeys []string foreignKeys []string
updateMap = map[string]interface{}{} updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue) tx = association.DB.Model(modelValue)
) )
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values}) tx.Not(clause.IN{Column: column, Values: values})
} }
@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error {
} }
} }
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
} }
@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error {
} }
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values}) tx.Where(clause.IN{Column: column, Values: values})
} else { } else {
return ErrPrimaryKeyRequired return ErrPrimaryKeyRequired
} }
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
} }
@ -186,11 +186,11 @@ func (association *Association) Delete(values ...interface{}) error {
case schema.BelongsTo: case schema.BelongsTo:
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -198,11 +198,11 @@ func (association *Association) Delete(values ...interface{}) error {
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -228,11 +228,11 @@ func (association *Association) Delete(values ...interface{}) error {
} }
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -241,11 +241,11 @@ func (association *Association) Delete(values ...interface{}) error {
if association.Error == nil { if association.Error == nil {
// clean up deleted values's foreign key // clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) { cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(data); !zero { if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() { switch fieldValue.Kind() {
@ -253,7 +253,7 @@ func (association *Association) Delete(values ...interface{}) error {
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ { for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields { for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
} }
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
@ -261,23 +261,23 @@ func (association *Association) Delete(values ...interface{}) error {
} }
} }
association.Error = rel.Field.Set(data, validFieldValues.Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
case reflect.Struct: case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields { for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue) primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
} }
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break break
} }
if rel.JoinTable == nil { if rel.JoinTable == nil {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey || ref.PrimaryValue != "" { if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} else { } else {
association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} }
} }
} }
@ -329,14 +329,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch rv.Kind() { switch rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if rv.Len() > 0 { if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
} }
} }
case reflect.Struct: case reflect.Struct:
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
@ -344,7 +344,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem() elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
if clear { if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
} }
@ -373,7 +373,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
if association.Error == nil { if association.Error == nil {
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
} }
} }
} }
@ -421,7 +421,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
// clear old data // clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err association.Error = err
break break
} }
@ -429,7 +429,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
if association.Relationship.JoinTable == nil { if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err association.Error = err
break break
} }
@ -453,12 +453,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
case reflect.Struct: case reflect.Struct:
// clear old data // clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil && association.Error == nil { if association.Relationship.JoinTable == nil && association.Error == nil {
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} }
} }
} }
@ -475,7 +475,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
for _, assignBack := range assignBacks { for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
if assignBack.Index > 0 { if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else { } else {
@ -486,7 +486,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
func (association *Association) buildCondition() *DB { func (association *Association) buildCondition() *DB {
var ( var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue) tx = association.DB.Model(modelValue)
) )

View File

@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
setupReferences := func(obj reflect.Value, elem reflect.Value) { setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References { for _, ref := range rel.References {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem) pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(obj, pv)) db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv dest[ref.ForeignKey.DBName] = pv
@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
break break
} }
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
objs = append(objs, obj) objs = append(objs, obj)
if isPtr { if isPtr {
elems = reflect.Append(elems, rv) elems = reflect.Append(elems, rv)
@ -76,8 +76,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
} }
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct { if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
rv := rel.Field.ReflectValueOf(obj) rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(rv, fv)) db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
} }
} }
@ -149,8 +149,8 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr { if f.Kind() != reflect.Ptr {
f = f.Addr() f = f.Addr()
} }
@ -158,10 +158,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns := make([]string, 0, len(rel.References)) assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv) ref.ForeignKey.Set(db.Statement.Context, f, fv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue) ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)
} }
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{} identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v) pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
ref.ForeignKey.Set(elem, pv) ref.ForeignKey.Set(db.Statement.Context, elem, pv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue) ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)
} }
} }
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields { for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(elem); !ok { if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv) relPrimaryValues = append(relPrimaryValues, pfv)
} }
} }
@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
joinValue := reflect.New(rel.JoinTable.ModelType) joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
ref.ForeignKey.Set(joinValue, fv) ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue) ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)
} else { } else {
fv, _ := ref.PrimaryKey.ValueOf(elem) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
ref.ForeignKey.Set(joinValue, fv) ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)
} }
} }
joins = reflect.Append(joins, joinValue) joins = reflect.Append(joins, joinValue)
} }
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)

View File

@ -117,9 +117,9 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero { if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
} }
} }
@ -130,16 +130,16 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
} }
} }
} }
case reflect.Struct: case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero { if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)
} }
} }
} }
@ -219,23 +219,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values[i] = make([]interface{}, len(values.Columns)) values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface values.Values[i][idx] = field.DefaultValueInterface
field.Set(rv, field.DefaultValueInterface) field.Set(stmt.Context, rv, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(rv, curTime) field.Set(stmt.Context, rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(rv, curTime) field.Set(stmt.Context, rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(rv); !isZero { if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 { if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
} }
@ -259,23 +259,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface values.Values[0][idx] = field.DefaultValueInterface
field.Set(stmt.ReflectValue, field.DefaultValueInterface) field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.ReflectValue, curTime) field.Set(stmt.Context, stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(stmt.ReflectValue, curTime) field.Set(stmt.Context, stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
} }
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue) values.Values[0] = append(values.Values[0], rvOfvalue)
} }

View File

@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
switch rel.Type { switch rel.Type {
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false withoutConditions := false
@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
} }
} }
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values}) queryConds = append(queryConds, clause.IN{Column: column, Values: values})
@ -123,7 +123,7 @@ func Delete(config *Config) func(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Delete{}) db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -131,7 +131,7 @@ func Delete(config *Config) func(db *gorm.DB) {
} }
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {

View File

@ -48,7 +48,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 { if len(joinForeignValues) == 0 {
return return
} }
@ -63,11 +63,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < joinResults.Len(); i++ { for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i) joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields { for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinIndexValue) fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue)
} }
for idx, field := range joinRelForeignFields { for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue)
} }
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -76,7 +76,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields)
} else { } else {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -92,7 +92,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 { if len(foreignValues) == 0 {
return return
} }
@ -125,17 +125,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
case reflect.Struct: case reflect.Struct:
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default: default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default: default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
} }
} }
} }
@ -143,7 +143,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i) elem := reflectResults.Index(i)
for idx, field := range relForeignFields { for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(elem) fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem)
} }
datas, ok := identityMap[utils.ToStringKey(fieldValues...)] datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
@ -154,7 +154,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
for _, data := range datas { for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(data) reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
} }
@ -162,12 +162,12 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectFieldValue = reflect.Indirect(reflectFieldValue) reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() { switch reflectFieldValue.Kind() {
case reflect.Struct: case reflect.Struct:
rel.Field.Set(data, elem.Interface()) rel.Field.Set(db.Statement.Context, data, elem.Interface())
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())
} else { } else {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
} }
} }
} }

View File

@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) {
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields { for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
} }
} }

View File

@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok { if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])
} }
} }
} }
@ -137,13 +137,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value) field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
} }
} }
case reflect.Struct: case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() { if stmt.ReflectValue.CanAddr() {
field.Set(stmt.ReflectValue, value) field.Set(stmt.Context, stmt.ReflectValue, value)
} }
} }
default: default:
@ -165,7 +165,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool var notZero bool
for idx, field := range stmt.Schema.PrimaryFields { for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value} exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero notZero = notZero || !isZero
} }
@ -178,7 +178,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
case reflect.Struct: case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }
@ -258,7 +258,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field := updatingSchema.LookUpField(dbName); field != nil { if field := updatingSchema.LookUpField(dbName); field != nil {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue) value, isZero := field.ValueOf(stmt.Context, updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond { if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano() value = stmt.DB.NowFunc().UnixNano()
@ -278,7 +278,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
} else { } else {
if value, isZero := field.ValueOf(updatingValue); !isZero { if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }

View File

@ -83,7 +83,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
case reflect.Struct: case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields { for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero { if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
} }
} }
@ -199,7 +199,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
@ -216,11 +216,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
switch column := eq.Column.(type) { switch column := eq.Column.(type) {
case string: case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil { if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value))
} }
case clause.Column: case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value))
} }
} }
} else if andCond, ok := expr.(clause.AndConditions); ok { } else if andCond, ok := expr.(clause.AndConditions); ok {
@ -238,9 +238,9 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
case reflect.Struct: case reflect.Struct:
for _, f := range s.Fields { for _, f := range s.Fields {
if f.Readable { if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero { if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v)) tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v))
} }
} }
} }

View File

@ -77,11 +77,11 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
if sch != nil { if sch != nil {
for idx, column := range columns { for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable { if field := sch.LookUpField(column); field != nil && field.Readable {
field.Set(reflectValue, values[idx]) field.Set(db.Statement.Context, reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok { if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(reflectValue) relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
value := reflect.ValueOf(values[idx]).Elem() value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
@ -91,7 +91,7 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
relValue.Set(reflect.New(relValue.Type().Elem())) relValue.Set(reflect.New(relValue.Type().Elem()))
} }
field.Set(relValue, values[idx]) field.Set(db.Statement.Context, relValue, values[idx])
} }
} }
} }
@ -244,7 +244,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflectValue.Index(int(db.RowsAffected)) elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing { if onConflictDonothing {
for _, field := range fields { for _, field := range fields {
if _, ok := field.ValueOf(elem); !ok { if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
db.RowsAffected++ db.RowsAffected++
goto BEGIN goto BEGIN
} }

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
@ -68,9 +69,9 @@ type Field struct {
Schema *Schema Schema *Schema
EmbeddedSchema *Schema EmbeddedSchema *Schema
OwnerSchema *Schema OwnerSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value ReflectValueOf func(context.Context, reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool) ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error Set func(context.Context, reflect.Value, interface{}) error
IgnoreMigration bool IgnoreMigration bool
} }
@ -408,22 +409,26 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
return field return field
} }
type GORMValuer interface {
GormValue(context.Context, *Field) (interface{}, error)
}
// create valuer, setter when parse struct // create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() { func (field *Field) setupValuerAndSetter() {
// ValueOf // ValueOf
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
return fieldValue.Interface(), fieldValue.IsZero() return fieldValue.Interface(), fieldValue.IsZero()
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
return fieldValue.Interface(), fieldValue.IsZero() return fieldValue.Interface(), fieldValue.IsZero()
} }
default: default:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index { for _, idx := range field.StructField.Index {
@ -450,15 +455,15 @@ func (field *Field) setupValuerAndSetter() {
// ReflectValueOf // ReflectValueOf
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
field.ReflectValueOf = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]) return reflect.Indirect(value).Field(field.StructField.Index[0])
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.ReflectValueOf = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
} }
default: default:
field.ReflectValueOf = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
v := reflect.Indirect(value) v := reflect.Indirect(value)
for idx, fieldIdx := range field.StructField.Index { for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 { if fieldIdx >= 0 {
@ -483,22 +488,22 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
if v == nil { if v == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
// Optimal value type acquisition for v // Optimal value type acquisition for v
reflectValType := reflectV.Type() reflectValType := reflectV.Type()
if reflectValType.AssignableTo(field.FieldType) { if reflectValType.AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
return return
} else if reflectValType.ConvertibleTo(field.FieldType) { } else if reflectValType.ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
return return
} else if field.FieldType.Kind() == reflect.Ptr { } else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
fieldType := field.FieldType.Elem() fieldType := field.FieldType.Elem()
if reflectValType.AssignableTo(fieldType) { if reflectValType.AssignableTo(fieldType) {
@ -521,13 +526,13 @@ func (field *Field) setupValuerAndSetter() {
if reflectV.Kind() == reflect.Ptr { if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() { if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
err = setter(value, reflectV.Elem().Interface()) err = setter(ctx, value, reflectV.Elem().Interface())
} }
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = setter(value, v) err = setter(ctx, value, v)
} }
} else { } else {
return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) return fmt.Errorf("failed to set value %+v to field %s", v, field.Name)
@ -540,191 +545,191 @@ func (field *Field) setupValuerAndSetter() {
// Set // Set
switch field.FieldType.Kind() { switch field.FieldType.Kind() {
case reflect.Bool: case reflect.Bool:
field.Set = func(value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case bool: case bool:
field.ReflectValueOf(value).SetBool(data) field.ReflectValueOf(ctx, value).SetBool(data)
case *bool: case *bool:
if data != nil { if data != nil {
field.ReflectValueOf(value).SetBool(*data) field.ReflectValueOf(ctx, value).SetBool(*data)
} else { } else {
field.ReflectValueOf(value).SetBool(false) field.ReflectValueOf(ctx, value).SetBool(false)
} }
case int64: case int64:
if data > 0 { if data > 0 {
field.ReflectValueOf(value).SetBool(true) field.ReflectValueOf(ctx, value).SetBool(true)
} else { } else {
field.ReflectValueOf(value).SetBool(false) field.ReflectValueOf(ctx, value).SetBool(false)
} }
case string: case string:
b, _ := strconv.ParseBool(data) b, _ := strconv.ParseBool(data)
field.ReflectValueOf(value).SetBool(b) field.ReflectValueOf(ctx, value).SetBool(b)
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return nil return nil
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case int64: case int64:
field.ReflectValueOf(value).SetInt(data) field.ReflectValueOf(ctx, value).SetInt(data)
case int: case int:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int8: case int8:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int16: case int16:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int32: case int32:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint: case uint:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint8: case uint8:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint16: case uint16:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint32: case uint32:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint64: case uint64:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float32: case float32:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float64: case float64:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case []byte: case []byte:
return field.Set(value, string(data)) return field.Set(ctx, value, string(data))
case string: case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil { if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetInt(i) field.ReflectValueOf(ctx, value).SetInt(i)
} else { } else {
return err return err
} }
case time.Time: case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano()) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else { } else {
field.ReflectValueOf(value).SetInt(data.Unix()) field.ReflectValueOf(ctx, value).SetInt(data.Unix())
} }
case *time.Time: case *time.Time:
if data != nil { if data != nil {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano()) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else { } else {
field.ReflectValueOf(value).SetInt(data.Unix()) field.ReflectValueOf(ctx, value).SetInt(data.Unix())
} }
} else { } else {
field.ReflectValueOf(value).SetInt(0) field.ReflectValueOf(ctx, value).SetInt(0)
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case uint64: case uint64:
field.ReflectValueOf(value).SetUint(data) field.ReflectValueOf(ctx, value).SetUint(data)
case uint: case uint:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint8: case uint8:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint16: case uint16:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint32: case uint32:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int64: case int64:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int: case int:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int8: case int8:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int16: case int16:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int32: case int32:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float32: case float32:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float64: case float64:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case []byte: case []byte:
return field.Set(value, string(data)) return field.Set(ctx, value, string(data))
case time.Time: case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
} else { } else {
field.ReflectValueOf(value).SetUint(uint64(data.Unix())) field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
} }
case string: case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil { if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetUint(i) field.ReflectValueOf(ctx, value).SetUint(i)
} else { } else {
return err return err
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case float64: case float64:
field.ReflectValueOf(value).SetFloat(data) field.ReflectValueOf(ctx, value).SetFloat(data)
case float32: case float32:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int64: case int64:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int: case int:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int8: case int8:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int16: case int16:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int32: case int32:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint: case uint:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint8: case uint8:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint16: case uint16:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint32: case uint32:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint64: case uint64:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case []byte: case []byte:
return field.Set(value, string(data)) return field.Set(ctx, value, string(data))
case string: case string:
if i, err := strconv.ParseFloat(data, 64); err == nil { if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValueOf(value).SetFloat(i) field.ReflectValueOf(ctx, value).SetFloat(i)
} else { } else {
return err return err
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
case reflect.String: case reflect.String:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case string: case string:
field.ReflectValueOf(value).SetString(data) field.ReflectValueOf(ctx, value).SetString(data)
case []byte: case []byte:
field.ReflectValueOf(value).SetString(string(data)) field.ReflectValueOf(ctx, value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(value).SetString(utils.ToString(data)) field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
case float64, float32: case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
@ -732,41 +737,41 @@ func (field *Field) setupValuerAndSetter() {
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) { switch fieldValue.Elem().Interface().(type) {
case time.Time: case time.Time:
field.Set = func(value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case time.Time: case time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
if data != nil { if data != nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
} else { } else {
field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
} }
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(t)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
} else { } else {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return nil return nil
} }
case *time.Time: case *time.Time:
field.Set = func(value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case time.Time: case time.Time:
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
fieldValue.Elem().Set(reflect.ValueOf(v)) fieldValue.Elem().Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
if v == "" { if v == "" {
return nil return nil
@ -778,27 +783,27 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return nil return nil
} }
default: default:
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner // pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
return field.Set(value, reflectV.Elem().Interface()) return field.Set(ctx, value, reflectV.Elem().Interface())
} }
} else { } else {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
@ -813,30 +818,30 @@ func (field *Field) setupValuerAndSetter() {
} }
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok { } else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner // struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
return field.Set(value, reflectV.Elem().Interface()) return field.Set(ctx, value, reflectV.Elem().Interface())
} }
} else { } else {
if valuer, ok := v.(driver.Valuer); ok { if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value() v, _ = valuer.Value()
} }
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
} }
return return
} }
} else { } else {
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
} }
} }

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -576,7 +577,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
return &constraint return &constraint
} }
func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table table := rel.FieldSchema.Table
foreignFields := []*Field{} foreignFields := []*Field{}
relForeignKeys := []string{} relForeignKeys := []string{}
@ -616,7 +617,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
} }
} }
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues) column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values}) conds = append(conds, clause.IN{Column: column, Values: values})

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"context"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
@ -59,13 +60,13 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
} }
// GetRelationsValues get relations's values from a reflect value // GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels { for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) { appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(value); !isZero { if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(value)) result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
switch result.Kind() { switch result.Kind() {
case reflect.Struct: case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result.Addr()) reflectResults = reflect.Append(reflectResults, result.Addr())
@ -97,7 +98,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle
} }
// GetIdentityFieldValuesMap get identity map from fields // GetIdentityFieldValuesMap get identity map from fields
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
var ( var (
results = [][]interface{}{} results = [][]interface{}{}
dataResults = map[string][]reflect.Value{} dataResults = map[string][]reflect.Value{}
@ -110,7 +111,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
results = [][]interface{}{make([]interface{}, len(fields))} results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields { for idx, field := range fields {
results[0][idx], zero = field.ValueOf(reflectValue) results[0][idx], zero = field.ValueOf(ctx, reflectValue)
notZero = notZero || !zero notZero = notZero || !zero
} }
@ -135,7 +136,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
fieldValues := make([]interface{}, len(fields)) fieldValues := make([]interface{}, len(fields))
notZero = false notZero = false
for idx, field := range fields { for idx, field := range fields {
fieldValues[idx], zero = field.ValueOf(elem) fieldValues[idx], zero = field.ValueOf(ctx, elem)
notZero = notZero || !zero notZero = notZero || !zero
} }
@ -155,12 +156,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
} }
// GetIdentityFieldValuesMapFromValues get identity map from fields // GetIdentityFieldValuesMapFromValues get identity map from fields
func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
resultsMap := map[string][]reflect.Value{} resultsMap := map[string][]reflect.Value{}
results := [][]interface{}{} results := [][]interface{}{}
for _, v := range values { for _, v := range values {
rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
for k, v := range rm { for k, v := range rm {
resultsMap[k] = append(resultsMap[k], v...) resultsMap[k] = append(resultsMap[k], v...)
} }

View File

@ -135,7 +135,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
stmt.SetColumn(sd.Field.DBName, curTime, true) stmt.SetColumn(sd.Field.DBName, curTime, true)
if stmt.Schema != nil { if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -143,7 +143,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
} }
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {

View File

@ -391,7 +391,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields { for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name] selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
@ -405,7 +405,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields { for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name] selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
@ -564,7 +564,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
switch destValue.Kind() { switch destValue.Kind() {
case reflect.Struct: case reflect.Struct:
field.Set(destValue, value) field.Set(stmt.Context, destValue, value)
default: default:
stmt.AddError(ErrInvalidData) stmt.AddError(ErrInvalidData)
} }
@ -574,10 +574,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 { if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value) field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
} }
} else { } else {
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)
} }
case reflect.Struct: case reflect.Struct:
if !stmt.ReflectValue.CanAddr() { if !stmt.ReflectValue.CanAddr() {
@ -585,7 +585,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
return return
} }
field.Set(stmt.ReflectValue, value) field.Set(stmt.Context, stmt.ReflectValue, value)
} }
} else { } else {
stmt.AddError(ErrInvalidField) stmt.AddError(ErrInvalidField)
@ -605,7 +605,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool { changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(modelValue) fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok { if v, ok := stmt.Dest.(map[string]interface{}); ok {
if fv, ok := v[field.Name]; ok { if fv, ok := v[field.Name]; ok {
@ -619,7 +619,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
destValue = destValue.Elem() destValue = destValue.Elem()
} }
changedValue, zero := field.ValueOf(destValue) changedValue, zero := field.ValueOf(stmt.Context, destValue)
return !zero && !utils.AssertEqual(changedValue, fieldValue) return !zero && !utils.AssertEqual(changedValue, fieldValue)
} }
} }