From 61bc71a53b552b625c0ade37289941a0192a7493 Mon Sep 17 00:00:00 2001 From: daheige Date: Tue, 7 Sep 2021 22:07:04 +0800 Subject: [PATCH] refactor if...else and for logic --- callbacks/create.go | 82 +++++++++++++++---------- callbacks/delete.go | 139 ++++++++++++++++++++++--------------------- callbacks/preload.go | 38 ++++++------ callbacks/update.go | 15 ++--- 4 files changed, 149 insertions(+), 125 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 8a3c593c..0a85dd78 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -62,46 +62,57 @@ func Create(config *Config) func(db *gorm.DB) { return } - db.RowsAffected, _ = result.RowsAffected() + db.RowsAffected, err = result.RowsAffected() + if err != nil { + db.AddError(err) + return + } if db.RowsAffected != 0 && db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + insertID, err := result.LastInsertId() + if err != nil { + db.AddError(err) + return + } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } + if insertID <= 0 { + // maybe log record + return + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } } } - } else { - db.AddError(err) + case reflect.Struct: + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } } @@ -195,9 +206,14 @@ func CreateWithReturning(db *gorm.DB) { } } } else if !db.DryRun && db.Error == nil { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + + db.RowsAffected, err = result.RowsAffected() + if err != nil { db.AddError(err) } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 91659c51..57dadaee 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -25,82 +25,87 @@ func BeforeDelete(db *gorm.DB) { func DeleteBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + if !restricted { + return + } - if restricted { - for column, v := range selectColumns { - if v { - if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { - switch rel.Type { - case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) - modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) - withoutConditions := false - if db.Statement.Unscoped { - tx = tx.Unscoped() - } + for column, v := range selectColumns { + if !v { + continue + } - if len(db.Statement.Selects) > 0 { - selects := make([]string, 0, len(db.Statement.Selects)) - for _, s := range db.Statement.Selects { - if s == clause.Associations { - selects = append(selects, s) - } else if strings.HasPrefix(s, column+".") { - selects = append(selects, strings.TrimPrefix(s, column+".")) - } - } + rel, ok := db.Statement.Schema.Relationships.Relations[column] + if !ok { + continue + } - if len(selects) > 0 { - tx = tx.Select(selects) - } - } + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } - for _, cond := range queryConds { - if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { - withoutConditions = true - break - } - } - - if !withoutConditions { - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } - } - case schema.Many2Many: - var ( - queryConds = make([]clause.Expression, 0, len(rel.References)) - foreignFields = make([]*schema.Field, 0, len(rel.References)) - relForeignKeys = make([]string, 0, len(rel.References)) - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() - table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) - ) - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - foreignFields = append(foreignFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) - } else if ref.PrimaryValue != "" { - queryConds = append(queryConds, clause.Eq{ - Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) - } - } - - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) - column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) - queryConds = append(queryConds, clause.IN{Column: column, Values: values}) - - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { + selects = append(selects, strings.TrimPrefix(s, columnPrefix)) } } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return } } } + } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 9882590c..92019376 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -145,27 +145,29 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload fieldValues[idx], _ = field.ValueOf(elem) } - if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok { - for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) - if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { - reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) - } + datas, ok := identityMap[utils.ToStringKey(fieldValues...)] + if !ok { + db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", + elem.Interface())) + } - reflectFieldValue = reflect.Indirect(reflectFieldValue) - switch reflectFieldValue.Kind() { - case reflect.Struct: - rel.Field.Set(data, reflectResults.Index(i).Interface()) - case reflect.Slice, reflect.Array: - if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) - } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) - } + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + rel.Field.Set(data, elem.Interface()) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } - } else { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 75bb02db..1018e089 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -61,11 +61,12 @@ func Update(db *gorm.DB) { if db.Statement.SQL.String() == "" { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { + set := ConvertToAssignments(db.Statement) + if len(set) == 0 { return } + + db.Statement.AddClause(set) db.Statement.Build(db.Statement.BuildClauses...) } @@ -76,12 +77,12 @@ func Update(db *gorm.DB) { if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { + if err != nil { db.AddError(err) + return } + + db.RowsAffected, _ = result.RowsAffected() } } }