From 029c663d027d3f81b4f95d2434f1726c64c36682 Mon Sep 17 00:00:00 2001 From: daheige Date: Wed, 3 Nov 2021 22:01:25 +0800 Subject: [PATCH] adjust code for Create --- callbacks/create.go | 119 +++++++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 52 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 656273fb..36e165a0 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -65,66 +65,81 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Build(db.Statement.BuildClauses...) } - if !db.DryRun && db.Error == nil { + isDryRun := !db.DryRun && db.Error == nil + if !isDryRun { + return + } - if ok, mode := hasReturning(db, supportReturning); ok { - if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { - mode |= gorm.ScanOnConflictDoNothing - } + ok, mode := hasReturning(db, supportReturning) + if ok { + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing } - if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - gorm.Scan(rows, db, mode) - rows.Close() - } - } else { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } - if err != nil { - db.AddError(err) - return - } + rows, err := db.Statement.ConnPool.QueryContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + rows.Close() + } - db.RowsAffected, _ = result.RowsAffected() - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + return + } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + result, err := db.Statement.ConnPool.ExecContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if err != nil { + db.AddError(err) + return + } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && + db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } - } else { - db.AddError(err) } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } + case reflect.Struct: + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } } }