Merge branch 'master' into patch-oracle
This commit is contained in:
commit
b2fdc25f73
5
.github/labels.json
vendored
5
.github/labels.json
vendored
@ -10,6 +10,11 @@
|
|||||||
"colour": "#EDEDED",
|
"colour": "#EDEDED",
|
||||||
"description": "general questions"
|
"description": "general questions"
|
||||||
},
|
},
|
||||||
|
"invalid_question": {
|
||||||
|
"name": "type:invalid question",
|
||||||
|
"colour": "#CF2E1F",
|
||||||
|
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
|
||||||
|
},
|
||||||
"with_playground": {
|
"with_playground": {
|
||||||
"name": "type:with reproduction steps",
|
"name": "type:with reproduction steps",
|
||||||
"colour": "#00ff00",
|
"colour": "#00ff00",
|
||||||
|
22
.github/workflows/invalid_question.yml
vendored
Normal file
22
.github/workflows/invalid_question.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
name: "Close invalid questions issues"
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "*/10 * * * *"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
stale:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
ACTIONS_STEP_DEBUG: true
|
||||||
|
steps:
|
||||||
|
- name: Close Stale Issues
|
||||||
|
uses: actions/stale@v3.0.7
|
||||||
|
with:
|
||||||
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||||
|
stale-issue-label: "status:stale"
|
||||||
|
days-before-stale: 0
|
||||||
|
days-before-close: 2
|
||||||
|
remove-stale-when-updated: true
|
||||||
|
only-labels: "type:invalid question"
|
||||||
|
|
2
.github/workflows/missing_playground.yml
vendored
2
.github/workflows/missing_playground.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
|||||||
uses: actions/stale@v3.0.7
|
uses: actions/stale@v3.0.7
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs."
|
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||||
stale-issue-label: "status:stale"
|
stale-issue-label: "status:stale"
|
||||||
days-before-stale: 0
|
days-before-stale: 0
|
||||||
days-before-close: 2
|
days-before-close: 2
|
||||||
|
16
callbacks.go
16
callbacks.go
@ -8,7 +8,6 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
@ -75,17 +74,22 @@ func (cs *callbacks) Raw() *processor {
|
|||||||
func (p *processor) Execute(db *DB) {
|
func (p *processor) Execute(db *DB) {
|
||||||
curTime := time.Now()
|
curTime := time.Now()
|
||||||
stmt := db.Statement
|
stmt := db.Statement
|
||||||
db.RowsAffected = 0
|
|
||||||
|
|
||||||
if stmt.Model == nil {
|
if stmt.Model == nil {
|
||||||
stmt.Model = stmt.Dest
|
stmt.Model = stmt.Dest
|
||||||
|
} else if stmt.Dest == nil {
|
||||||
|
stmt.Dest = stmt.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
if stmt.Model != nil {
|
if stmt.Model != nil {
|
||||||
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
|
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
|
||||||
|
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
|
||||||
|
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
|
||||||
|
} else {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if stmt.Dest != nil {
|
if stmt.Dest != nil {
|
||||||
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
|
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
|
||||||
@ -154,7 +158,7 @@ func (p *processor) compile() (err error) {
|
|||||||
p.callbacks = callbacks
|
p.callbacks = callbacks
|
||||||
|
|
||||||
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||||
logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -177,7 +181,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Remove(name string) error {
|
func (c *callback) Remove(name string) error {
|
||||||
logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||||
c.name = name
|
c.name = name
|
||||||
c.remove = true
|
c.remove = true
|
||||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||||
@ -185,7 +189,7 @@ func (c *callback) Remove(name string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Replace(name string, fn func(*DB)) error {
|
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||||
logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||||
c.name = name
|
c.name = name
|
||||||
c.handler = fn
|
c.handler = fn
|
||||||
c.replace = true
|
c.replace = true
|
||||||
@ -215,7 +219,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
||||||
logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
||||||
}
|
}
|
||||||
names = append(names, c.name)
|
names = append(names, c.name)
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SaveBeforeAssociations(db *gorm.DB) {
|
func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
@ -65,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
|
||||||
DoNothing: true,
|
|
||||||
}).Create(elems.Interface()).Error) == nil {
|
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
setupReferences(objs[i], elems.Index(i))
|
setupReferences(objs[i], elems.Index(i))
|
||||||
}
|
}
|
||||||
@ -80,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||||||
rv = rv.Addr()
|
rv = rv.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
|
||||||
DoNothing: true,
|
|
||||||
}).Create(rv.Interface()).Error) == nil {
|
|
||||||
setupReferences(db.Statement.ReflectValue, rv)
|
setupReferences(db.Statement.ReflectValue, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -144,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||||
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
|
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
).Create(elems.Interface()).Error)
|
||||||
}).Create(elems.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||||
@ -167,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||||
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
|
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
).Create(f.Interface()).Error)
|
||||||
}).Create(f.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -229,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||||
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
|
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
).Create(elems.Interface()).Error)
|
||||||
}).Create(elems.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error)
|
db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
|
||||||
|
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
appendToJoins(objs[i], elems.Index(i))
|
appendToJoins(objs[i], elems.Index(i))
|
||||||
@ -310,3 +304,32 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
|
||||||
|
if stmt.DB.FullSaveAssociations {
|
||||||
|
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
|
||||||
|
for _, dbName := range s.DBNames {
|
||||||
|
if !s.LookUpField(dbName).PrimaryKey {
|
||||||
|
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(defaultUpdatingColumns) > 0 {
|
||||||
|
var columns []clause.Column
|
||||||
|
if s.PrioritizedPrimaryField != nil {
|
||||||
|
columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
|
||||||
|
} else {
|
||||||
|
for _, dbName := range s.PrimaryFieldDBNames {
|
||||||
|
columns = append(columns, clause.Column{Name: dbName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return clause.OnConflict{
|
||||||
|
Columns: columns,
|
||||||
|
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return clause.OnConflict{DoNothing: true}
|
||||||
|
}
|
||||||
|
@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
|||||||
deleteCallback := db.Callback().Delete()
|
deleteCallback := db.Callback().Delete()
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||||
|
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||||
deleteCallback.Register("gorm:delete", Delete)
|
deleteCallback.Register("gorm:delete", Delete)
|
||||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -87,9 +88,12 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if insertID > 0 {
|
if insertID > 0 {
|
||||||
|
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||||
|
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
@ -252,8 +256,18 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
stmt.SQL.Grow(stmt.ReflectValue.Len() * 15)
|
stmt.SQL.Grow(stmt.ReflectValue.Len() * 15)
|
||||||
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
||||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||||
|
if stmt.ReflectValue.Len() == 0 {
|
||||||
|
stmt.AddError(gorm.ErrEmptySlice)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||||
|
if !rv.IsValid() {
|
||||||
|
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||||
for idx, column := range values.Columns {
|
for idx, column := range values.Columns {
|
||||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||||
|
@ -2,6 +2,7 @@ package callbacks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -21,6 +22,85 @@ func BeforeDelete(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteBeforeAssociations(db *gorm.DB) {
|
||||||
|
if db.Error == nil && db.Statement.Schema != nil {
|
||||||
|
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||||
|
|
||||||
|
if restricted {
|
||||||
|
for column, v := range selectColumns {
|
||||||
|
if v {
|
||||||
|
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
|
||||||
|
switch rel.Type {
|
||||||
|
case schema.HasOne, schema.HasMany:
|
||||||
|
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
|
||||||
|
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
|
tx := db.Session(&gorm.Session{}).Model(modelValue)
|
||||||
|
withoutConditions := false
|
||||||
|
|
||||||
|
if len(db.Statement.Selects) > 0 {
|
||||||
|
var selects []string
|
||||||
|
for _, s := range db.Statement.Selects {
|
||||||
|
if s == clause.Associations {
|
||||||
|
selects = append(selects, s)
|
||||||
|
} else if strings.HasPrefix(s, column+".") {
|
||||||
|
selects = append(selects, strings.TrimPrefix(s, column+"."))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(selects) > 0 {
|
||||||
|
tx = tx.Select(selects)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cond := range queryConds {
|
||||||
|
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
|
||||||
|
withoutConditions = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !withoutConditions {
|
||||||
|
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case schema.Many2Many:
|
||||||
|
var (
|
||||||
|
queryConds []clause.Expression
|
||||||
|
foreignFields []*schema.Field
|
||||||
|
relForeignKeys []string
|
||||||
|
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
|
table = rel.JoinTable.Table
|
||||||
|
tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
if ref.OwnPrimaryKey {
|
||||||
|
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||||
|
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||||
|
} else if ref.PrimaryValue != "" {
|
||||||
|
queryConds = append(queryConds, clause.Eq{
|
||||||
|
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||||
|
Value: ref.PrimaryValue,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
|
||||||
|
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
||||||
|
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
||||||
|
|
||||||
|
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Delete(db *gorm.DB) {
|
func Delete(db *gorm.DB) {
|
||||||
if db.Error == nil {
|
if db.Error == nil {
|
||||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
|
@ -46,6 +46,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
|||||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(mapValues) == 0 {
|
||||||
|
stmt.AddError(gorm.ErrEmptySlice)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for idx, mapValue := range mapValues {
|
for idx, mapValue := range mapValues {
|
||||||
for k, v := range mapValue {
|
for k, v := range mapValue {
|
||||||
if stmt.Schema != nil {
|
if stmt.Schema != nil {
|
||||||
|
@ -107,6 +107,26 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
|||||||
|
|
||||||
fieldValues := make([]interface{}, len(relForeignFields))
|
fieldValues := make([]interface{}, len(relForeignFields))
|
||||||
|
|
||||||
|
// clean up old values before preloading
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
switch rel.Type {
|
||||||
|
case schema.HasMany, schema.Many2Many:
|
||||||
|
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
|
||||||
|
default:
|
||||||
|
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
|
||||||
|
}
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
switch rel.Type {
|
||||||
|
case schema.HasMany, schema.Many2Many:
|
||||||
|
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
|
||||||
|
default:
|
||||||
|
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < reflectResults.Len(); i++ {
|
for i := 0; i < reflectResults.Len(); i++ {
|
||||||
elem := reflectResults.Index(i)
|
elem := reflectResults.Index(i)
|
||||||
for idx, field := range relForeignFields {
|
for idx, field := range relForeignFields {
|
||||||
|
@ -11,11 +11,13 @@ func RowQuery(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun {
|
||||||
if _, ok := db.Get("rows"); ok {
|
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
|
||||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
} else {
|
} else {
|
||||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.RowsAffected = -1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
|||||||
|
|
||||||
// Order specify order when retrieve records from database
|
// Order specify order when retrieve records from database
|
||||||
// db.Order("name DESC")
|
// db.Order("name DESC")
|
||||||
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
|
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package clause
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,19 +31,23 @@ func (expr Expr) Build(builder Builder) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _, v := range []byte(expr.SQL) {
|
for _, v := range []byte(expr.SQL) {
|
||||||
if v == '?' {
|
if v == '?' && len(expr.Vars) > idx {
|
||||||
if afterParenthesis {
|
if afterParenthesis {
|
||||||
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
||||||
builder.AddVar(builder, expr.Vars[idx])
|
builder.AddVar(builder, expr.Vars[idx])
|
||||||
} else {
|
} else {
|
||||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
if rv.Len() == 0 {
|
||||||
|
builder.AddVar(builder, nil)
|
||||||
|
} else {
|
||||||
for i := 0; i < rv.Len(); i++ {
|
for i := 0; i < rv.Len(); i++ {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
builder.WriteByte(',')
|
builder.WriteByte(',')
|
||||||
}
|
}
|
||||||
builder.AddVar(builder, rv.Index(i).Interface())
|
builder.AddVar(builder, rv.Index(i).Interface())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
builder.AddVar(builder, expr.Vars[idx])
|
builder.AddVar(builder, expr.Vars[idx])
|
||||||
}
|
}
|
||||||
@ -85,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) {
|
|||||||
for k, v := range value {
|
for k, v := range value {
|
||||||
namedMap[k] = v
|
namedMap[k] = v
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
modelType := reflectValue.Type()
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||||
|
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,7 +110,7 @@ func (expr NamedExpr) Build(builder Builder) {
|
|||||||
if v == '@' && !inName {
|
if v == '@' && !inName {
|
||||||
inName = true
|
inName = true
|
||||||
name = []byte{}
|
name = []byte{}
|
||||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' {
|
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
|
||||||
if inName {
|
if inName {
|
||||||
if nv, ok := namedMap[string(name)]; ok {
|
if nv, ok := namedMap[string(name)]; ok {
|
||||||
builder.AddVar(builder, nv)
|
builder.AddVar(builder, nv)
|
||||||
@ -106,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
builder.WriteByte(v)
|
builder.WriteByte(v)
|
||||||
} else if v == '?' {
|
} else if v == '?' && len(expr.Vars) > idx {
|
||||||
builder.AddVar(builder, expr.Vars[idx])
|
builder.AddVar(builder, expr.Vars[idx])
|
||||||
idx++
|
idx++
|
||||||
} else if inName {
|
} else if inName {
|
||||||
|
@ -37,6 +37,11 @@ func TestExpr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNamedExpr(t *testing.T) {
|
func TestNamedExpr(t *testing.T) {
|
||||||
|
type NamedArgument struct {
|
||||||
|
Name1 string
|
||||||
|
Name2 string
|
||||||
|
}
|
||||||
|
|
||||||
results := []struct {
|
results := []struct {
|
||||||
SQL string
|
SQL string
|
||||||
Result string
|
Result string
|
||||||
@ -66,6 +71,15 @@ func TestNamedExpr(t *testing.T) {
|
|||||||
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||||
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||||
|
}, {
|
||||||
|
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
|
||||||
|
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}},
|
||||||
|
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||||
|
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||||
|
}, {
|
||||||
|
SQL: "create table ? (? ?, ? ?)",
|
||||||
|
Vars: []interface{}{},
|
||||||
|
Result: "create table ? (? ?, ? ?)",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for idx, result := range results {
|
for idx, result := range results {
|
||||||
|
@ -33,10 +33,8 @@ func (limit Limit) MergeClause(clause *Clause) {
|
|||||||
clause.Name = ""
|
clause.Name = ""
|
||||||
|
|
||||||
if v, ok := clause.Expression.(Limit); ok {
|
if v, ok := clause.Expression.(Limit); ok {
|
||||||
if limit.Limit == 0 && v.Limit > 0 {
|
if limit.Limit == 0 && v.Limit != 0 {
|
||||||
limit.Limit = v.Limit
|
limit.Limit = v.Limit
|
||||||
} else if limit.Limit < 0 {
|
|
||||||
limit.Limit = 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if limit.Offset == 0 && v.Offset > 0 {
|
if limit.Offset == 0 && v.Offset > 0 {
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
package clause
|
package clause
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
// Where where clause
|
// Where where clause
|
||||||
type Where struct {
|
type Where struct {
|
||||||
Exprs []Expression
|
Exprs []Expression
|
||||||
@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wrapInParentheses := false
|
||||||
for idx, expr := range where.Exprs {
|
for idx, expr := range where.Exprs {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||||
@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(where.Exprs) > 1 {
|
||||||
|
switch v := expr.(type) {
|
||||||
|
case OrConditions:
|
||||||
|
if len(v.Exprs) == 1 {
|
||||||
|
if e, ok := v.Exprs[0].(Expr); ok {
|
||||||
|
sql := strings.ToLower(e.SQL)
|
||||||
|
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case AndConditions:
|
||||||
|
if len(v.Exprs) == 1 {
|
||||||
|
if e, ok := v.Exprs[0].(Expr); ok {
|
||||||
|
sql := strings.ToLower(e.SQL)
|
||||||
|
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case Expr:
|
||||||
|
sql := strings.ToLower(v.SQL)
|
||||||
|
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if wrapInParentheses {
|
||||||
|
builder.WriteString(`(`)
|
||||||
expr.Build(builder)
|
expr.Build(builder)
|
||||||
|
builder.WriteString(`)`)
|
||||||
|
wrapInParentheses = false
|
||||||
|
} else {
|
||||||
|
expr.Build(builder)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) {
|
|||||||
func And(exprs ...Expression) Expression {
|
func And(exprs ...Expression) Expression {
|
||||||
if len(exprs) == 0 {
|
if len(exprs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
} else if len(exprs) == 1 {
|
||||||
|
return exprs[0]
|
||||||
}
|
}
|
||||||
return AndConditions{Exprs: exprs}
|
return AndConditions{Exprs: exprs}
|
||||||
}
|
}
|
||||||
@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) {
|
|||||||
if len(not.Exprs) > 1 {
|
if len(not.Exprs) > 1 {
|
||||||
builder.WriteByte('(')
|
builder.WriteByte('(')
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, c := range not.Exprs {
|
for idx, c := range not.Exprs {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
builder.WriteString(" AND ")
|
builder.WriteString(" AND ")
|
||||||
@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) {
|
|||||||
negationBuilder.NegationBuild(builder)
|
negationBuilder.NegationBuild(builder)
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString("NOT ")
|
builder.WriteString("NOT ")
|
||||||
|
e, wrapInParentheses := c.(Expr)
|
||||||
|
if wrapInParentheses {
|
||||||
|
sql := strings.ToLower(e.SQL)
|
||||||
|
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
|
||||||
|
builder.WriteByte('(')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.Build(builder)
|
c.Build(builder)
|
||||||
|
|
||||||
|
if wrapInParentheses {
|
||||||
|
builder.WriteByte(')')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(not.Exprs) > 1 {
|
if len(not.Exprs) > 1 {
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
}
|
}
|
||||||
|
@ -27,4 +27,8 @@ var (
|
|||||||
ErrRegistered = errors.New("registered")
|
ErrRegistered = errors.New("registered")
|
||||||
// ErrInvalidField invalid field
|
// ErrInvalidField invalid field
|
||||||
ErrInvalidField = errors.New("invalid field")
|
ErrInvalidField = errors.New("invalid field")
|
||||||
|
// ErrEmptySlice empty slice found
|
||||||
|
ErrEmptySlice = errors.New("empty slice found")
|
||||||
|
// ErrDryRunModeUnsupported dry run mode unsupported
|
||||||
|
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
|
||||||
)
|
)
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
@ -51,7 +52,8 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||||||
tx.callbacks.Update().Execute(tx)
|
tx.callbacks.Update().Execute(tx)
|
||||||
|
|
||||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||||
if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) {
|
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||||
|
if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
|
||||||
return tx.Create(value)
|
return tx.Create(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -331,22 +333,43 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Row() *sql.Row {
|
func (db *DB) Row() *sql.Row {
|
||||||
tx := db.getInstance()
|
tx := db.getInstance().InstanceSet("rows", false)
|
||||||
tx.callbacks.Row().Execute(tx)
|
tx.callbacks.Row().Execute(tx)
|
||||||
return tx.Statement.Dest.(*sql.Row)
|
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||||
|
if !ok && tx.DryRun {
|
||||||
|
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
|
||||||
|
}
|
||||||
|
return row
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Rows() (*sql.Rows, error) {
|
func (db *DB) Rows() (*sql.Rows, error) {
|
||||||
tx := db.Set("rows", true)
|
tx := db.getInstance().InstanceSet("rows", true)
|
||||||
tx.callbacks.Row().Execute(tx)
|
tx.callbacks.Row().Execute(tx)
|
||||||
return tx.Statement.Dest.(*sql.Rows), tx.Error
|
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||||
|
if !ok && tx.DryRun && tx.Error == nil {
|
||||||
|
tx.Error = ErrDryRunModeUnsupported
|
||||||
|
}
|
||||||
|
return rows, tx.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan scan value to a struct
|
// Scan scan value to a struct
|
||||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||||
|
currentLogger, newLogger := db.Logger, logger.Recorder.New()
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.Dest = dest
|
tx.Logger = newLogger
|
||||||
tx.callbacks.Query().Execute(tx)
|
if rows, err := tx.Rows(); err != nil {
|
||||||
|
tx.AddError(err)
|
||||||
|
} else {
|
||||||
|
defer rows.Close()
|
||||||
|
if rows.Next() {
|
||||||
|
tx.ScanRows(rows, dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
|
||||||
|
return newLogger.SQL, tx.RowsAffected
|
||||||
|
}, tx.Error)
|
||||||
|
tx.Logger = currentLogger
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,9 +400,14 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
|||||||
|
|
||||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||||
tx := db.getInstance()
|
tx := db.getInstance()
|
||||||
tx.Error = tx.Statement.Parse(dest)
|
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
|
||||||
|
tx.AddError(err)
|
||||||
|
}
|
||||||
tx.Statement.Dest = dest
|
tx.Statement.Dest = dest
|
||||||
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest))
|
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
||||||
|
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||||
|
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
|
||||||
|
}
|
||||||
Scan(rows, tx, true)
|
Scan(rows, tx, true)
|
||||||
return tx.Error
|
return tx.Error
|
||||||
}
|
}
|
||||||
|
13
gorm.go
13
gorm.go
@ -20,6 +20,8 @@ type Config struct {
|
|||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
NamingStrategy schema.Namer
|
NamingStrategy schema.Namer
|
||||||
|
// FullSaveAssociations full save associations
|
||||||
|
FullSaveAssociations bool
|
||||||
// Logger
|
// Logger
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
// NowFunc the function to be used when creating a new timestamp
|
// NowFunc the function to be used when creating a new timestamp
|
||||||
@ -64,6 +66,7 @@ type Session struct {
|
|||||||
WithConditions bool
|
WithConditions bool
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
AllowGlobalUpdate bool
|
AllowGlobalUpdate bool
|
||||||
|
FullSaveAssociations bool
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
NowFunc func() time.Time
|
NowFunc func() time.Time
|
||||||
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
|
|||||||
txConfig.AllowGlobalUpdate = true
|
txConfig.AllowGlobalUpdate = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.FullSaveAssociations {
|
||||||
|
txConfig.FullSaveAssociations = true
|
||||||
|
}
|
||||||
|
|
||||||
if config.Context != nil {
|
if config.Context != nil {
|
||||||
tx.Statement = tx.Statement.clone()
|
tx.Statement = tx.Statement.clone()
|
||||||
tx.Statement.DB = tx
|
tx.Statement.DB = tx
|
||||||
@ -169,12 +176,15 @@ func (db *DB) Session(config *Session) *DB {
|
|||||||
|
|
||||||
if config.PrepareStmt {
|
if config.PrepareStmt {
|
||||||
if v, ok := db.cacheStore.Load("preparedStmt"); ok {
|
if v, ok := db.cacheStore.Load("preparedStmt"); ok {
|
||||||
|
tx.Statement = tx.Statement.clone()
|
||||||
preparedStmt := v.(*PreparedStmtDB)
|
preparedStmt := v.(*PreparedStmtDB)
|
||||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||||
ConnPool: db.Config.ConnPool,
|
ConnPool: db.Config.ConnPool,
|
||||||
Mux: preparedStmt.Mux,
|
Mux: preparedStmt.Mux,
|
||||||
Stmts: preparedStmt.Stmts,
|
Stmts: preparedStmt.Stmts,
|
||||||
}
|
}
|
||||||
|
txConfig.ConnPool = tx.Statement.ConnPool
|
||||||
|
txConfig.PrepareStmt = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,6 +326,9 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
|
|||||||
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||||
f.DataType = ref.ForeignKey.DataType
|
f.DataType = ref.ForeignKey.DataType
|
||||||
f.GORMDataType = ref.ForeignKey.GORMDataType
|
f.GORMDataType = ref.ForeignKey.GORMDataType
|
||||||
|
if f.Size == 0 {
|
||||||
|
f.Size = ref.ForeignKey.Size
|
||||||
|
}
|
||||||
ref.ForeignKey = f
|
ref.ForeignKey = f
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||||
|
@ -2,6 +2,7 @@ package logger
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
@ -19,6 +20,7 @@ const (
|
|||||||
Magenta = "\033[35m"
|
Magenta = "\033[35m"
|
||||||
Cyan = "\033[36m"
|
Cyan = "\033[36m"
|
||||||
White = "\033[37m"
|
White = "\033[37m"
|
||||||
|
BlueBold = "\033[34;1m"
|
||||||
MagentaBold = "\033[35;1m"
|
MagentaBold = "\033[35;1m"
|
||||||
RedBold = "\033[31;1m"
|
RedBold = "\033[31;1m"
|
||||||
YellowBold = "\033[33;1m"
|
YellowBold = "\033[33;1m"
|
||||||
@ -54,29 +56,33 @@ type Interface interface {
|
|||||||
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
|
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
var (
|
||||||
|
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
||||||
|
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||||
SlowThreshold: 100 * time.Millisecond,
|
SlowThreshold: 100 * time.Millisecond,
|
||||||
LogLevel: Warn,
|
LogLevel: Warn,
|
||||||
Colorful: true,
|
Colorful: true,
|
||||||
})
|
})
|
||||||
|
Recorder = traceRecorder{Interface: Default}
|
||||||
|
)
|
||||||
|
|
||||||
func New(writer Writer, config Config) Interface {
|
func New(writer Writer, config Config) Interface {
|
||||||
var (
|
var (
|
||||||
infoStr = "%s\n[info] "
|
infoStr = "%s\n[info] "
|
||||||
warnStr = "%s\n[warn] "
|
warnStr = "%s\n[warn] "
|
||||||
errStr = "%s\n[error] "
|
errStr = "%s\n[error] "
|
||||||
traceStr = "%s\n[%v] [rows:%d] %s"
|
traceStr = "%s\n[%.3fms] [rows:%v] %s"
|
||||||
traceWarnStr = "%s\n[%v] [rows:%d] %s"
|
traceWarnStr = "%s\n[%.3fms] [rows:%v] %s"
|
||||||
traceErrStr = "%s %s\n[%v] [rows:%d] %s"
|
traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.Colorful {
|
if config.Colorful {
|
||||||
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||||
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||||
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||||
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
|
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
|
||||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &logger{
|
return &logger{
|
||||||
@ -133,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
|
|||||||
switch {
|
switch {
|
||||||
case err != nil && l.LogLevel >= Error:
|
case err != nil && l.LogLevel >= Error:
|
||||||
sql, rows := fc()
|
sql, rows := fc()
|
||||||
|
if rows == -1 {
|
||||||
|
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||||
|
} else {
|
||||||
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||||
|
}
|
||||||
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
|
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
|
||||||
sql, rows := fc()
|
sql, rows := fc()
|
||||||
|
if rows == -1 {
|
||||||
|
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||||
|
} else {
|
||||||
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||||
|
}
|
||||||
case l.LogLevel >= Info:
|
case l.LogLevel >= Info:
|
||||||
sql, rows := fc()
|
sql, rows := fc()
|
||||||
|
if rows == -1 {
|
||||||
|
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||||
|
} else {
|
||||||
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type traceRecorder struct {
|
||||||
|
Interface
|
||||||
|
BeginAt time.Time
|
||||||
|
SQL string
|
||||||
|
RowsAffected int64
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l traceRecorder) New() *traceRecorder {
|
||||||
|
return &traceRecorder{Interface: l.Interface}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||||
|
l.BeginAt = begin
|
||||||
|
l.SQL, l.RowsAffected = fc()
|
||||||
|
l.Err = err
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isPrintable(s []byte) bool {
|
func isPrintable(s []byte) bool {
|
||||||
@ -24,19 +26,38 @@ var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
|
|||||||
|
|
||||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||||
var convertParams func(interface{}, int)
|
var convertParams func(interface{}, int)
|
||||||
var vars = make([]interface{}, len(avars))
|
var vars = make([]string, len(avars))
|
||||||
copy(vars, avars)
|
|
||||||
|
|
||||||
convertParams = func(v interface{}, idx int) {
|
convertParams = func(v interface{}, idx int) {
|
||||||
switch v := v.(type) {
|
switch v := v.(type) {
|
||||||
case bool:
|
case bool:
|
||||||
vars[idx] = fmt.Sprint(v)
|
vars[idx] = strconv.FormatBool(v)
|
||||||
case time.Time:
|
case time.Time:
|
||||||
if v.IsZero() {
|
if v.IsZero() {
|
||||||
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
|
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
|
||||||
} else {
|
} else {
|
||||||
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
|
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
|
||||||
}
|
}
|
||||||
|
case *time.Time:
|
||||||
|
if v != nil {
|
||||||
|
if v.IsZero() {
|
||||||
|
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
|
||||||
|
} else {
|
||||||
|
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vars[idx] = "NULL"
|
||||||
|
}
|
||||||
|
case fmt.Stringer:
|
||||||
|
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||||
|
case driver.Valuer:
|
||||||
|
reflectValue := reflect.ValueOf(v)
|
||||||
|
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||||
|
r, _ := v.Value()
|
||||||
|
convertParams(r, idx)
|
||||||
|
} else {
|
||||||
|
vars[idx] = "NULL"
|
||||||
|
}
|
||||||
case []byte:
|
case []byte:
|
||||||
if isPrintable(v) {
|
if isPrintable(v) {
|
||||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||||
@ -44,7 +65,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
vars[idx] = escaper + "<binary>" + escaper
|
vars[idx] = escaper + "<binary>" + escaper
|
||||||
}
|
}
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
vars[idx] = fmt.Sprintf("%d", v)
|
vars[idx] = utils.ToString(v)
|
||||||
case float64, float32:
|
case float64, float32:
|
||||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||||
case string:
|
case string:
|
||||||
@ -70,18 +91,30 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, v := range vars {
|
for idx, v := range avars {
|
||||||
convertParams(v, idx)
|
convertParams(v, idx)
|
||||||
}
|
}
|
||||||
|
|
||||||
if numericPlaceholder == nil {
|
if numericPlaceholder == nil {
|
||||||
for _, v := range vars {
|
var idx int
|
||||||
sql = strings.Replace(sql, "?", v.(string), 1)
|
var newSQL strings.Builder
|
||||||
|
|
||||||
|
for _, v := range []byte(sql) {
|
||||||
|
if v == '?' {
|
||||||
|
if len(vars) > idx {
|
||||||
|
newSQL.WriteString(vars[idx])
|
||||||
|
idx++
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
newSQL.WriteByte(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = newSQL.String()
|
||||||
} else {
|
} else {
|
||||||
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
||||||
for idx, v := range vars {
|
for idx, v := range vars {
|
||||||
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
|
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,13 +1,39 @@
|
|||||||
package logger_test
|
package logger_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/now"
|
"github.com/jinzhu/now"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type JSON json.RawMessage
|
||||||
|
|
||||||
|
func (j JSON) Value() (driver.Value, error) {
|
||||||
|
if len(j) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return json.RawMessage(j).MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExampleStruct struct {
|
||||||
|
Name string
|
||||||
|
Val string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ExampleStruct) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func format(v []byte, escaper string) string {
|
||||||
|
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||||
|
}
|
||||||
|
|
||||||
func TestExplainSQL(t *testing.T) {
|
func TestExplainSQL(t *testing.T) {
|
||||||
type role string
|
type role string
|
||||||
type password []byte
|
type password []byte
|
||||||
@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) {
|
|||||||
tt = now.MustParse("2020-02-23 11:10:10")
|
tt = now.MustParse("2020-02-23 11:10:10")
|
||||||
myrole = role("admin")
|
myrole = role("admin")
|
||||||
pwd = password([]byte("pass"))
|
pwd = password([]byte("pass"))
|
||||||
|
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||||
|
js = JSON(jsVal)
|
||||||
|
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||||
|
es = ExampleStruct{Name: "test", Val: "test"}
|
||||||
)
|
)
|
||||||
|
|
||||||
results := []struct {
|
results := []struct {
|
||||||
@ -29,6 +59,12 @@ func TestExplainSQL(t *testing.T) {
|
|||||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||||
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
||||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||||
@ -47,6 +83,18 @@ func TestExplainSQL(t *testing.T) {
|
|||||||
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
|
||||||
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, r := range results {
|
for idx, r := range results {
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
// Migrator returns migrator
|
// Migrator returns migrator
|
||||||
func (db *DB) Migrator() Migrator {
|
func (db *DB) Migrator() Migrator {
|
||||||
return db.Dialector.Migrator(db)
|
return db.Dialector.Migrator(db.Session(&Session{WithConditions: true}))
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoMigrate run auto migration for given models
|
// AutoMigrate run auto migration for given models
|
||||||
|
@ -71,7 +71,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||||
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
|
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
|
||||||
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
|
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
|
||||||
} else {
|
} else if field.DefaultValue != "(-)" {
|
||||||
expr.SQL += " DEFAULT " + field.DefaultValue
|
expr.SQL += " DEFAULT " + field.DefaultValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||||
|
if !tx.Migrator().HasIndex(value, idx.Name) {
|
||||||
|
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -297,10 +306,12 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
|
|||||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||||
|
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
|
||||||
return m.DB.Exec(
|
return m.DB.Exec(
|
||||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
|
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType,
|
||||||
).Error
|
).Error
|
||||||
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||||
})
|
})
|
||||||
@ -354,9 +365,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
alterColumn = true
|
alterColumn = true
|
||||||
} else {
|
} else {
|
||||||
// has size in data type and not equal
|
// has size in data type and not equal
|
||||||
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
|
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1)
|
||||||
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
||||||
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
|
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
|
||||||
alterColumn = true
|
alterColumn = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -386,8 +397,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
|
|
||||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
||||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
defer rows.Close()
|
||||||
columnTypes, err = rows.ColumnTypes()
|
columnTypes, err = rows.ColumnTypes()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@ -584,6 +596,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||||||
var (
|
var (
|
||||||
modelNames, orderedModelNames []string
|
modelNames, orderedModelNames []string
|
||||||
orderedModelNamesMap = map[string]bool{}
|
orderedModelNamesMap = map[string]bool{}
|
||||||
|
parsedSchemas = map[*schema.Schema]bool{}
|
||||||
valuesMap = map[string]Dependency{}
|
valuesMap = map[string]Dependency{}
|
||||||
insertIntoOrderedList func(name string)
|
insertIntoOrderedList func(name string)
|
||||||
parseDependence func(value interface{}, addToList bool)
|
parseDependence func(value interface{}, addToList bool)
|
||||||
@ -593,23 +606,35 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||||||
dep := Dependency{
|
dep := Dependency{
|
||||||
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
||||||
}
|
}
|
||||||
|
beDependedOn := map[*schema.Schema]bool{}
|
||||||
if err := dep.Parse(value); err != nil {
|
if err := dep.Parse(value); err != nil {
|
||||||
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
||||||
}
|
}
|
||||||
|
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
parsedSchemas[dep.Statement.Schema] = true
|
||||||
|
|
||||||
for _, rel := range dep.Schema.Relationships.Relations {
|
for _, rel := range dep.Schema.Relationships.Relations {
|
||||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rel.JoinTable != nil {
|
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||||
if rel.Schema != rel.FieldSchema {
|
beDependedOn[rel.FieldSchema] = true
|
||||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rel.JoinTable != nil {
|
||||||
// append join value
|
// append join value
|
||||||
defer func(joinValue interface{}) {
|
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||||
|
if !beDependedOn[rel.FieldSchema] {
|
||||||
|
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||||
|
} else {
|
||||||
|
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
|
parseDependence(fieldValue, autoAdd)
|
||||||
|
}
|
||||||
parseDependence(joinValue, autoAdd)
|
parseDependence(joinValue, autoAdd)
|
||||||
}(reflect.New(rel.JoinTable.ModelType).Interface())
|
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() {
|
|||||||
db.Mux.Unlock()
|
db.Mux.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
db.Mux.RLock()
|
db.Mux.RLock()
|
||||||
if stmt, ok := db.Stmts[query]; ok {
|
if stmt, ok := db.Stmts[query]; ok {
|
||||||
db.Mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
|||||||
return stmt, nil
|
return stmt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
|
stmt, err := db.ConnPool.PrepareContext(ctx, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
db.Stmts[query] = stmt
|
db.Stmts[query] = stmt
|
||||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||||
@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||||
stmt, err := db.prepare(query)
|
stmt, err := db.prepare(ctx, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = stmt.ExecContext(ctx, args...)
|
result, err = stmt.ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||||
stmt, err := db.prepare(query)
|
stmt, err := db.prepare(ctx, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = stmt.QueryContext(ctx, args...)
|
rows, err = stmt.QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
stmt, err := db.prepare(query)
|
stmt, err := db.prepare(ctx, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return stmt.QueryRowContext(ctx, args...)
|
return stmt.QueryRowContext(ctx, args...)
|
||||||
}
|
}
|
||||||
@ -99,10 +99,24 @@ type PreparedStmtTX struct {
|
|||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tx *PreparedStmtTX) Commit() error {
|
||||||
|
if tx.Tx != nil {
|
||||||
|
return tx.Tx.Commit()
|
||||||
|
}
|
||||||
|
return ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *PreparedStmtTX) Rollback() error {
|
||||||
|
if tx.Tx != nil {
|
||||||
|
return tx.Tx.Rollback()
|
||||||
|
}
|
||||||
|
return ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.Mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
@ -114,7 +128,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -128,9 +142,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
|
return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
|
||||||
}
|
}
|
||||||
return &sql.Row{}
|
return &sql.Row{}
|
||||||
}
|
}
|
||||||
|
86
scan.go
86
scan.go
@ -2,22 +2,63 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||||
|
if db.Statement.Schema != nil {
|
||||||
|
for idx, name := range columns {
|
||||||
|
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||||
|
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
} else if len(columnTypes) > 0 {
|
||||||
|
for idx, columnType := range columnTypes {
|
||||||
|
if columnType.ScanType() != nil {
|
||||||
|
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||||
|
} else {
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for idx := range columns {
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
|
||||||
|
for idx, column := range columns {
|
||||||
|
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
|
||||||
|
mapValue[column] = reflectValue.Interface()
|
||||||
|
if valuer, ok := mapValue[column].(driver.Valuer); ok {
|
||||||
|
mapValue[column], _ = valuer.Value()
|
||||||
|
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
|
||||||
|
mapValue[column] = string(b)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mapValue[column] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
values := make([]interface{}, len(columns))
|
values := make([]interface{}, len(columns))
|
||||||
|
db.RowsAffected = 0
|
||||||
|
|
||||||
switch dest := db.Statement.Dest.(type) {
|
switch dest := db.Statement.Dest.(type) {
|
||||||
case map[string]interface{}, *map[string]interface{}:
|
case map[string]interface{}, *map[string]interface{}:
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
for idx := range columns {
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
values[idx] = new(interface{})
|
prepareValues(values, db, columnTypes, columns)
|
||||||
}
|
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
@ -28,41 +69,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||||||
mapValue = *v
|
mapValue = *v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
scanIntoMap(mapValue, values, columns)
|
||||||
for idx, column := range columns {
|
|
||||||
if v, ok := values[idx].(*interface{}); ok {
|
|
||||||
if v == nil {
|
|
||||||
mapValue[column] = nil
|
|
||||||
} else {
|
|
||||||
mapValue[column] = *v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case *[]map[string]interface{}:
|
case *[]map[string]interface{}:
|
||||||
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
for idx := range columns {
|
prepareValues(values, db, columnTypes, columns)
|
||||||
values[idx] = new(interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
mapValue := map[string]interface{}{}
|
mapValue := map[string]interface{}{}
|
||||||
for idx, column := range columns {
|
scanIntoMap(mapValue, values, columns)
|
||||||
if v, ok := values[idx].(*interface{}); ok {
|
|
||||||
if v == nil {
|
|
||||||
mapValue[column] = nil
|
|
||||||
} else {
|
|
||||||
mapValue[column] = *v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
*dest = append(*dest, mapValue)
|
*dest = append(*dest, mapValue)
|
||||||
}
|
}
|
||||||
case *int, *int64, *uint, *uint64, *float32, *float64:
|
case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time:
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
@ -114,7 +136,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pluck values into slice of data
|
// pluck values into slice of data
|
||||||
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
|
isPluck := false
|
||||||
|
if len(fields) == 1 {
|
||||||
|
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok {
|
||||||
|
isPluck = true
|
||||||
|
} else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) {
|
||||||
|
isPluck = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
|
@ -18,6 +18,8 @@ type DataType string
|
|||||||
|
|
||||||
type TimeType int64
|
type TimeType int64
|
||||||
|
|
||||||
|
var TimeReflectType = reflect.TypeOf(time.Time{})
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UnixSecond TimeType = 1
|
UnixSecond TimeType = 1
|
||||||
UnixMillisecond TimeType = 2
|
UnixMillisecond TimeType = 2
|
||||||
@ -70,6 +72,8 @@ type Field struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
|
var err error
|
||||||
|
|
||||||
field := &Field{
|
field := &Field{
|
||||||
Name: fieldStruct.Name,
|
Name: fieldStruct.Name,
|
||||||
BindNames: []string{fieldStruct.Name},
|
BindNames: []string{fieldStruct.Name},
|
||||||
@ -100,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
var getRealFieldValue func(reflect.Value)
|
var getRealFieldValue func(reflect.Value)
|
||||||
getRealFieldValue = func(v reflect.Value) {
|
getRealFieldValue = func(v reflect.Value) {
|
||||||
rv := reflect.Indirect(v)
|
rv := reflect.Indirect(v)
|
||||||
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
|
||||||
for i := 0; i < rv.Type().NumField(); i++ {
|
for i := 0; i < rv.Type().NumField(); i++ {
|
||||||
newFieldType := rv.Type().Field(i).Type
|
newFieldType := rv.Type().Field(i).Type
|
||||||
for newFieldType.Kind() == reflect.Ptr {
|
for newFieldType.Kind() == reflect.Ptr {
|
||||||
@ -151,7 +155,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||||
var err error
|
|
||||||
if field.Size, err = strconv.Atoi(num); err != nil {
|
if field.Size, err = strconv.Atoi(num); err != nil {
|
||||||
field.Size = -1
|
field.Size = -1
|
||||||
}
|
}
|
||||||
@ -177,33 +180,42 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
field.Comment = val
|
field.Comment = val
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// default value is function or null or blank (primary keys)
|
||||||
|
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
|
||||||
|
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
|
||||||
switch reflect.Indirect(fieldValue).Kind() {
|
switch reflect.Indirect(fieldValue).Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
field.DataType = Bool
|
field.DataType = Bool
|
||||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||||
field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
|
if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
|
||||||
|
schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
field.DataType = Int
|
field.DataType = Int
|
||||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||||
field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
|
if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
|
||||||
|
schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
field.DataType = Uint
|
field.DataType = Uint
|
||||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||||
field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
|
if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
|
||||||
|
schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
field.DataType = Float
|
field.DataType = Float
|
||||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||||
field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
|
if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
|
||||||
|
schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.DataType = String
|
field.DataType = String
|
||||||
isFunc := strings.Contains(field.DefaultValue, "(") &&
|
|
||||||
strings.Contains(field.DefaultValue, ")")
|
|
||||||
|
|
||||||
if field.HasDefaultValue && !isFunc {
|
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||||
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
|
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
|
||||||
field.DefaultValue = strings.Trim(field.DefaultValue, "\"")
|
field.DefaultValue = strings.Trim(field.DefaultValue, "\"")
|
||||||
field.DefaultValueInterface = field.DefaultValue
|
field.DefaultValueInterface = field.DefaultValue
|
||||||
@ -211,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
@ -316,7 +328,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
|
|
||||||
cacheStore := &sync.Map{}
|
cacheStore := &sync.Map{}
|
||||||
cacheStore.Store(embeddedCacheKey, true)
|
cacheStore.Store(embeddedCacheKey, true)
|
||||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil {
|
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -331,10 +343,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
|
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
|
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
|
||||||
ef.DBName = prefix + ef.DBName
|
ef.DBName = prefix + ef.DBName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ef.PrimaryKey {
|
||||||
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
|
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
|
||||||
ef.PrimaryKey = true
|
ef.PrimaryKey = true
|
||||||
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
|
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
|
||||||
@ -350,6 +363,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
ef.HasDefaultValue = false
|
ef.HasDefaultValue = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for k, v := range field.TagSettings {
|
for k, v := range field.TagSettings {
|
||||||
ef.TagSettings[k] = v
|
ef.TagSettings[k] = v
|
||||||
@ -612,6 +626,14 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(value).SetUint(uint64(data))
|
||||||
case []byte:
|
case []byte:
|
||||||
return field.Set(value, string(data))
|
return field.Set(value, string(data))
|
||||||
|
case time.Time:
|
||||||
|
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||||
|
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
|
||||||
|
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||||
|
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
|
||||||
|
} else {
|
||||||
|
field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
|
||||||
|
}
|
||||||
case string:
|
case string:
|
||||||
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
|
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
|
||||||
field.ReflectValueOf(value).SetUint(i)
|
field.ReflectValueOf(value).SetUint(i)
|
||||||
@ -671,7 +693,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case []byte:
|
case []byte:
|
||||||
field.ReflectValueOf(value).SetString(string(data))
|
field.ReflectValueOf(value).SetString(string(data))
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
field.ReflectValueOf(value).SetString(fmt.Sprint(data))
|
field.ReflectValueOf(value).SetString(utils.ToString(data))
|
||||||
case float64, float32:
|
case float64, float32:
|
||||||
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
||||||
default:
|
default:
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
type Namer interface {
|
type Namer interface {
|
||||||
TableName(table string) string
|
TableName(table string) string
|
||||||
ColumnName(table, column string) string
|
ColumnName(table, column string) string
|
||||||
JoinTableName(table string) string
|
JoinTableName(joinTable string) string
|
||||||
RelationshipFKName(Relationship) string
|
RelationshipFKName(Relationship) string
|
||||||
CheckerName(table, column string) string
|
CheckerName(table, column string) string
|
||||||
IndexName(table, column string) string
|
IndexName(table, column string) string
|
||||||
@ -41,6 +41,10 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
|
|||||||
|
|
||||||
// JoinTableName convert string to join table name
|
// JoinTableName convert string to join table name
|
||||||
func (ns NamingStrategy) JoinTableName(str string) string {
|
func (ns NamingStrategy) JoinTableName(str string) string {
|
||||||
|
if strings.ToLower(str) == str {
|
||||||
|
return ns.TablePrefix + str
|
||||||
|
}
|
||||||
|
|
||||||
if ns.SingularTable {
|
if ns.SingularTable {
|
||||||
return ns.TablePrefix + toDBName(str)
|
return ns.TablePrefix + toDBName(str)
|
||||||
}
|
}
|
||||||
@ -49,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string {
|
|||||||
|
|
||||||
// RelationshipFKName generate fk name for relation
|
// RelationshipFKName generate fk name for relation
|
||||||
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
|
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
|
||||||
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name))
|
return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckerName generate checker name
|
// CheckerName generate checker name
|
||||||
func (ns NamingStrategy) CheckerName(table, column string) string {
|
func (ns NamingStrategy) CheckerName(table, column string) string {
|
||||||
return fmt.Sprintf("chk_%s_%s", table, column)
|
return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IndexName generate index name
|
// IndexName generate index name
|
||||||
func (ns NamingStrategy) IndexName(table, column string) string {
|
func (ns NamingStrategy) IndexName(table, column string) string {
|
||||||
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
|
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
|
||||||
|
idxName = strings.Replace(idxName, ".", "_", -1)
|
||||||
|
|
||||||
if utf8.RuneCountInString(idxName) > 64 {
|
if utf8.RuneCountInString(idxName) > 64 {
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
|
@ -32,3 +32,35 @@ func TestToDBName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNamingStrategy(t *testing.T) {
|
||||||
|
var ns = NamingStrategy{
|
||||||
|
TablePrefix: "public.",
|
||||||
|
SingularTable: true,
|
||||||
|
}
|
||||||
|
idxName := ns.IndexName("public.table", "name")
|
||||||
|
|
||||||
|
if idxName != "idx_public_table_name" {
|
||||||
|
t.Errorf("invalid index name generated, got %v", idxName)
|
||||||
|
}
|
||||||
|
|
||||||
|
chkName := ns.CheckerName("public.table", "name")
|
||||||
|
if chkName != "chk_public_table_name" {
|
||||||
|
t.Errorf("invalid checker name generated, got %v", chkName)
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTable := ns.JoinTableName("user_languages")
|
||||||
|
if joinTable != "public.user_languages" {
|
||||||
|
t.Errorf("invalid join table generated, got %v", joinTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTable2 := ns.JoinTableName("UserLanguage")
|
||||||
|
if joinTable2 != "public.user_language" {
|
||||||
|
t.Errorf("invalid join table generated, got %v", joinTable2)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := ns.TableName("Company")
|
||||||
|
if tableName != "public.company" {
|
||||||
|
t.Errorf("invalid table name generated, got %v", tableName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -165,6 +165,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||||||
// use same data type for foreign keys
|
// use same data type for foreign keys
|
||||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||||
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
|
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
|
||||||
|
if relation.Polymorphic.PolymorphicID.Size == 0 {
|
||||||
|
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
|
||||||
|
}
|
||||||
|
|
||||||
relation.References = append(relation.References, &Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryKey: primaryKeyField,
|
PrimaryKey: primaryKeyField,
|
||||||
@ -254,12 +257,18 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||||
|
Name: schema.Name + field.Name,
|
||||||
|
Type: schema.ModelType,
|
||||||
|
Tag: `gorm:"-"`,
|
||||||
|
})
|
||||||
|
|
||||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
}
|
}
|
||||||
relation.JoinTable.Name = many2many
|
relation.JoinTable.Name = many2many
|
||||||
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
|
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
|
||||||
relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields))
|
relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
|
||||||
|
|
||||||
relName := relation.Schema.Name
|
relName := relation.Schema.Name
|
||||||
relRefName := relation.FieldSchema.Name
|
relRefName := relation.FieldSchema.Name
|
||||||
@ -290,11 +299,15 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
}
|
}
|
||||||
|
|
||||||
// build references
|
// build references
|
||||||
for idx, f := range relation.JoinTable.Fields {
|
for _, f := range relation.JoinTable.Fields {
|
||||||
|
if f.Creatable || f.Readable || f.Updatable {
|
||||||
// use same data type for foreign keys
|
// use same data type for foreign keys
|
||||||
f.DataType = fieldsMap[f.Name].DataType
|
f.DataType = fieldsMap[f.Name].DataType
|
||||||
f.GORMDataType = fieldsMap[f.Name].GORMDataType
|
f.GORMDataType = fieldsMap[f.Name].GORMDataType
|
||||||
relation.JoinTable.PrimaryFields[idx] = f
|
if f.Size == 0 {
|
||||||
|
f.Size = fieldsMap[f.Name].Size
|
||||||
|
}
|
||||||
|
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
|
||||||
ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
|
ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
|
||||||
|
|
||||||
if ownPriamryField {
|
if ownPriamryField {
|
||||||
@ -321,6 +334,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
OwnPrimaryKey: ownPriamryField,
|
OwnPrimaryKey: ownPriamryField,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type guessLevel int
|
type guessLevel int
|
||||||
@ -428,6 +442,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
|
|||||||
// use same data type for foreign keys
|
// use same data type for foreign keys
|
||||||
foreignField.DataType = primaryFields[idx].DataType
|
foreignField.DataType = primaryFields[idx].DataType
|
||||||
foreignField.GORMDataType = primaryFields[idx].GORMDataType
|
foreignField.GORMDataType = primaryFields[idx].GORMDataType
|
||||||
|
if foreignField.Size == 0 {
|
||||||
|
foreignField.Size = primaryFields[idx].Size
|
||||||
|
}
|
||||||
|
|
||||||
relation.References = append(relation.References, &Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryKey: primaryFields[idx],
|
PrimaryKey: primaryFields[idx],
|
||||||
|
@ -206,16 +206,39 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
|
|||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
|
Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
|
||||||
Refer uint
|
Refer uint
|
||||||
}
|
}
|
||||||
|
|
||||||
checkStructRelation(t, &User{}, Relation{
|
checkStructRelation(t, &User{}, Relation{
|
||||||
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
|
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
|
||||||
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
|
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
|
||||||
References: []Reference{
|
References: []Reference{
|
||||||
{"ID", "User", "UserReferID", "user_profiles", "", true},
|
{"ID", "User", "UserReferID", "user_profile", "", true},
|
||||||
{"ID", "Profile", "ProfileRefer", "user_profiles", "", false},
|
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildReadonlyMany2ManyRelation(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserRefer uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
|
||||||
|
Refer uint
|
||||||
|
}
|
||||||
|
|
||||||
|
checkStructRelation(t, &User{}, Relation{
|
||||||
|
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
|
||||||
|
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
|
||||||
|
References: []Reference{
|
||||||
|
{"ID", "User", "UserReferID", "user_profile", "", true},
|
||||||
|
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -267,3 +290,34 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMultipleMany2Many(t *testing.T) {
|
||||||
|
type Thing struct {
|
||||||
|
ID int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Person struct {
|
||||||
|
ID int
|
||||||
|
Likes []Thing `gorm:"many2many:likes"`
|
||||||
|
Dislikes []Thing `gorm:"many2many:dislikes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
checkStructRelation(t, &Person{},
|
||||||
|
Relation{
|
||||||
|
Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
|
||||||
|
JoinTable: JoinTable{Name: "likes", Table: "likes"},
|
||||||
|
References: []Reference{
|
||||||
|
{"ID", "Person", "PersonID", "likes", "", true},
|
||||||
|
{"ID", "Thing", "ThingID", "likes", "", false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Relation{
|
||||||
|
Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
|
||||||
|
JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"},
|
||||||
|
References: []Reference{
|
||||||
|
{"ID", "Person", "PersonID", "dislikes", "", true},
|
||||||
|
{"ID", "Thing", "ThingID", "dislikes", "", false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||||||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||||
tableName = tabler.TableName()
|
tableName = tabler.TableName()
|
||||||
}
|
}
|
||||||
|
if en, ok := namer.(embeddedNamer); ok {
|
||||||
|
tableName = en.Table
|
||||||
|
}
|
||||||
|
|
||||||
schema := &Schema{
|
schema := &Schema{
|
||||||
Name: modelType.Name(),
|
Name: modelType.Name(),
|
||||||
@ -133,7 +136,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||||||
|
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
// nonexistence or shortest path or first appear prioritized if has permission
|
// nonexistence or shortest path or first appear prioritized if has permission
|
||||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
|
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||||
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
|
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
|
||||||
schema.DBNames = append(schema.DBNames, field.DBName)
|
schema.DBNames = append(schema.DBNames, field.DBName)
|
||||||
}
|
}
|
||||||
@ -219,7 +222,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||||||
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
|
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
|
||||||
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if field.DataType == "" && field.Creatable {
|
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
|
||||||
if schema.parseRelation(field); schema.err != nil {
|
if schema.parseRelation(field); schema.err != nil {
|
||||||
return schema, schema.err
|
return schema, schema.err
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package schema_test
|
package schema_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -194,6 +195,7 @@ func TestEmbeddedStruct(t *testing.T) {
|
|||||||
ID int
|
ID int
|
||||||
OwnerID int
|
OwnerID int
|
||||||
Name string
|
Name string
|
||||||
|
Ignored string `gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Corp struct {
|
type Corp struct {
|
||||||
@ -211,15 +213,87 @@ func TestEmbeddedStruct(t *testing.T) {
|
|||||||
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
||||||
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
|
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range fields {
|
for _, f := range fields {
|
||||||
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
|
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
|
||||||
|
if f.Name != "Ignored" {
|
||||||
f.Creatable = true
|
f.Creatable = true
|
||||||
f.Updatable = true
|
f.Updatable = true
|
||||||
f.Readable = true
|
f.Readable = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CustomizedNamingStrategy struct {
|
||||||
|
schema.NamingStrategy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ns CustomizedNamingStrategy) ColumnName(table, column string) string {
|
||||||
|
baseColumnName := ns.NamingStrategy.ColumnName(table, column)
|
||||||
|
|
||||||
|
if table == "" {
|
||||||
|
return baseColumnName
|
||||||
|
}
|
||||||
|
|
||||||
|
s := strings.Split(table, "_")
|
||||||
|
|
||||||
|
var prefix string
|
||||||
|
switch len(s) {
|
||||||
|
case 1:
|
||||||
|
prefix = s[0][:3]
|
||||||
|
case 2:
|
||||||
|
prefix = s[0][:1] + s[1][:2]
|
||||||
|
default:
|
||||||
|
prefix = s[0][:1] + s[1][:1] + s[2][:1]
|
||||||
|
}
|
||||||
|
return prefix + "_" + baseColumnName
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
|
||||||
|
type CorpBase struct {
|
||||||
|
gorm.Model
|
||||||
|
OwnerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Company struct {
|
||||||
|
ID int
|
||||||
|
OwnerID int
|
||||||
|
Name string
|
||||||
|
Ignored string `gorm:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Corp struct {
|
||||||
|
CorpBase
|
||||||
|
Base Company `gorm:"embedded;embeddedPrefix:company_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := []schema.Field{
|
||||||
|
{Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
||||||
|
{Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
|
||||||
|
if f.Name != "Ignored" {
|
||||||
|
f.Creatable = true
|
||||||
|
f.Updatable = true
|
||||||
|
f.Readable = true
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa
|
|||||||
return columns, queryValues
|
return columns, queryValues
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type embeddedNamer struct {
|
||||||
|
Table string
|
||||||
|
Namer
|
||||||
|
}
|
||||||
|
@ -25,14 +25,7 @@ func (n DeletedAt) Value() (driver.Value, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||||
return []clause.Interface{
|
return []clause.Interface{SoftDeleteQueryClause{Field: f}}
|
||||||
clause.Where{Exprs: []clause.Expression{
|
|
||||||
clause.Eq{
|
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
|
|
||||||
Value: nil,
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SoftDeleteQueryClause struct {
|
type SoftDeleteQueryClause struct {
|
||||||
|
17
statement.go
17
statement.go
@ -12,6 +12,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
@ -189,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||||||
writer.WriteString("(NULL)")
|
writer.WriteString("(NULL)")
|
||||||
}
|
}
|
||||||
case *DB:
|
case *DB:
|
||||||
subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance()
|
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance()
|
||||||
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
|
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
|
||||||
subdb.callbacks.Query().Execute(subdb)
|
subdb.callbacks.Query().Execute(subdb)
|
||||||
writer.WriteString(subdb.Statement.SQL.String())
|
writer.WriteString(subdb.Statement.SQL.String())
|
||||||
@ -298,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
|||||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
if _, ok := v[key].(driver.Valuer); ok {
|
||||||
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||||
|
} else if _, ok := v[key].(Valuer); ok {
|
||||||
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||||
|
} else {
|
||||||
values := make([]interface{}, reflectValue.Len())
|
values := make([]interface{}, reflectValue.Len())
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
values[i] = reflectValue.Index(i).Interface()
|
values[i] = reflectValue.Index(i).Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
conds = append(conds, clause.IN{Column: key, Values: values})
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||||
}
|
}
|
||||||
@ -317,9 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
|||||||
if field.Readable {
|
if field.Readable {
|
||||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -330,9 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
|||||||
if field.Readable {
|
if field.Readable {
|
||||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -287,6 +288,30 @@ func TestCreateEmptyStruct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateEmptySlice(t *testing.T) {
|
||||||
|
var data = []User{}
|
||||||
|
if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice {
|
||||||
|
t.Errorf("no data should be created, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sliceMap = []map[string]interface{}{}
|
||||||
|
if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice {
|
||||||
|
t.Errorf("no data should be created, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateInvalidSlice(t *testing.T) {
|
||||||
|
users := []*User{
|
||||||
|
GetUser("invalid_slice_1", Config{}),
|
||||||
|
GetUser("invalid_slice_2", Config{}),
|
||||||
|
nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) {
|
||||||
|
t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateWithExistingTimestamp(t *testing.T) {
|
func TestCreateWithExistingTimestamp(t *testing.T) {
|
||||||
user := User{Name: "CreateUserExistingTimestamp"}
|
user := User{Name: "CreateUserExistingTimestamp"}
|
||||||
curTime := now.MustParse("2016-01-01")
|
curTime := now.MustParse("2016-01-01")
|
||||||
|
@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) {
|
|||||||
FieldAllowSave3 string `gorm:"->:false;<-:create"`
|
FieldAllowSave3 string `gorm:"->:false;<-:create"`
|
||||||
FieldReadonly string `gorm:"->"`
|
FieldReadonly string `gorm:"->"`
|
||||||
FieldIgnore string `gorm:"-"`
|
FieldIgnore string `gorm:"-"`
|
||||||
AutoUnixCreateTime int64 `gorm:"autocreatetime"`
|
AutoUnixCreateTime int32 `gorm:"autocreatetime"`
|
||||||
AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"`
|
AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"`
|
||||||
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
|
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
|
||||||
AutoUnixUpdateTime int64 `gorm:"autoupdatetime"`
|
AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"`
|
||||||
AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"`
|
AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"`
|
||||||
AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"`
|
AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"`
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Migrator().DropTable(&CustomizeFieldStruct{})
|
DB.Migrator().DropTable(&CustomizeFieldStruct{})
|
||||||
@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) {
|
|||||||
t.Fatalf("invalid result: %#v", result)
|
t.Fatalf("invalid result: %#v", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 {
|
if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 {
|
||||||
t.Fatalf("invalid create/update unix time: %#v", result)
|
t.Fatalf("invalid create/update unix time: %#v", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 {
|
if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 {
|
||||||
t.Fatalf("invalid create/update unix milli time: %#v", result)
|
t.Fatalf("invalid create/update unix milli time: %#v", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 {
|
if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 {
|
||||||
t.Fatalf("invalid create/update unix nano time: %#v", result)
|
t.Fatalf("invalid create/update unix nano time: %#v", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) {
|
|||||||
var createWithDefaultTimeResult CustomizeFieldStruct
|
var createWithDefaultTimeResult CustomizeFieldStruct
|
||||||
DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name)
|
DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name)
|
||||||
|
|
||||||
if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
|
if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
|
||||||
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
|
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
|
if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
|
||||||
t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult)
|
t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
|
if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
|
||||||
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
|
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) {
|
|||||||
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteWithAssociations(t *testing.T) {
|
||||||
|
user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1})
|
||||||
|
|
||||||
|
if err := DB.Create(user).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil {
|
||||||
|
t.Fatalf("failed to delete user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Model(&user).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteSliceWithAssociations(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
*GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}),
|
||||||
|
*GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}),
|
||||||
|
*GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}),
|
||||||
|
*GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(users).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil {
|
||||||
|
t.Fatalf("failed to delete user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Model(&users).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -163,6 +163,8 @@ func TestEmbeddedRelations(t *testing.T) {
|
|||||||
DB.Migrator().DropTable(&AdvancedUser{})
|
DB.Migrator().DropTable(&AdvancedUser{})
|
||||||
|
|
||||||
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
|
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
|
||||||
|
if DB.Dialector.Name() != "sqlite" {
|
||||||
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
12
tests/go.mod
12
tests/go.mod
@ -7,11 +7,13 @@ require (
|
|||||||
github.com/jinzhu/now v1.1.1
|
github.com/jinzhu/now v1.1.1
|
||||||
github.com/lib/pq v1.6.0
|
github.com/lib/pq v1.6.0
|
||||||
github.com/stevefan1999-personal/gorm-driver-oracle latest
|
github.com/stevefan1999-personal/gorm-driver-oracle latest
|
||||||
gorm.io/driver/mysql v1.0.0
|
gorm.io/driver/mysql v1.0.1
|
||||||
gorm.io/driver/postgres v1.0.0
|
gorm.io/driver/postgres v1.0.2
|
||||||
gorm.io/driver/sqlite v1.1.0
|
gorm.io/driver/sqlite v1.1.3
|
||||||
gorm.io/driver/sqlserver v1.0.1
|
gorm.io/driver/sqlserver v1.0.4
|
||||||
gorm.io/gorm v1.9.19
|
gorm.io/gorm v1.20.2
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
|
||||||
|
replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2
|
||||||
|
@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
|||||||
|
|
||||||
t.Run("Pets", func(t *testing.T) {
|
t.Run("Pets", func(t *testing.T) {
|
||||||
if len(user.Pets) != len(expect.Pets) {
|
if len(user.Pets) != len(expect.Pets) {
|
||||||
t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets))
|
t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets))
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(user.Pets, func(i, j int) bool {
|
sort.Slice(user.Pets, func(i, j int) bool {
|
||||||
@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
|||||||
|
|
||||||
t.Run("Toys", func(t *testing.T) {
|
t.Run("Toys", func(t *testing.T) {
|
||||||
if len(user.Toys) != len(expect.Toys) {
|
if len(user.Toys) != len(expect.Toys) {
|
||||||
t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys))
|
t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys))
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(user.Toys, func(i, j int) bool {
|
sort.Slice(user.Toys, func(i, j int) bool {
|
||||||
@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
|||||||
|
|
||||||
t.Run("Team", func(t *testing.T) {
|
t.Run("Team", func(t *testing.T) {
|
||||||
if len(user.Team) != len(expect.Team) {
|
if len(user.Team) != len(expect.Team) {
|
||||||
t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team))
|
t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team))
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(user.Team, func(i, j int) bool {
|
sort.Slice(user.Team, func(i, j int) bool {
|
||||||
@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
|||||||
|
|
||||||
t.Run("Languages", func(t *testing.T) {
|
t.Run("Languages", func(t *testing.T) {
|
||||||
if len(user.Languages) != len(expect.Languages) {
|
if len(user.Languages) != len(expect.Languages) {
|
||||||
t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages))
|
t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages))
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(user.Languages, func(i, j int) bool {
|
sort.Slice(user.Languages, func(i, j int) bool {
|
||||||
@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
|||||||
|
|
||||||
t.Run("Friends", func(t *testing.T) {
|
t.Run("Friends", func(t *testing.T) {
|
||||||
if len(user.Friends) != len(expect.Friends) {
|
if len(user.Friends) != len(expect.Friends) {
|
||||||
t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends))
|
t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends))
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(user.Friends, func(i, j int) bool {
|
sort.Slice(user.Friends, func(i, j int) bool {
|
||||||
|
@ -5,12 +5,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Person struct {
|
type Person struct {
|
||||||
ID int
|
ID int
|
||||||
Name string
|
Name string
|
||||||
Addresses []Address `gorm:"many2many:person_addresses;"`
|
Addresses []Address `gorm:"many2many:person_addresses;"`
|
||||||
|
DeletedAt gorm.DeletedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
type Address struct {
|
type Address struct {
|
||||||
@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) {
|
|||||||
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
|
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
|
||||||
t.Fatalf("address should be deleted when clear with unscoped")
|
t.Fatalf("address should be deleted when clear with unscoped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
address2_1 := Address{Name: "address 2-1"}
|
||||||
|
address2_2 := Address{Name: "address 2-2"}
|
||||||
|
person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}}
|
||||||
|
DB.Create(&person2)
|
||||||
|
if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil {
|
||||||
|
t.Fatalf("failed to delete person, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 {
|
||||||
|
t.Errorf("person's addresses expects 2, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 {
|
||||||
|
t.Errorf("person's addresses expects 2, got %v", count)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Blog struct {
|
type Blog struct {
|
||||||
@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
|
|||||||
t.Fatalf("EN Blog's tags should be cleared")
|
t.Fatalf("EN Blog's tags should be cleared")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompositePrimaryKeysAssociations(t *testing.T) {
|
||||||
|
type Label struct {
|
||||||
|
BookID *uint `gorm:"primarykey"`
|
||||||
|
Name string `gorm:"primarykey"`
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Book struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
Labels []Label
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Label{}, &Book{})
|
||||||
|
if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil {
|
||||||
|
t.Fatalf("failed to migrate")
|
||||||
|
}
|
||||||
|
|
||||||
|
book := Book{
|
||||||
|
Name: "my book",
|
||||||
|
Labels: []Label{
|
||||||
|
{Name: "region", Value: "emea"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&book)
|
||||||
|
|
||||||
|
var result Book
|
||||||
|
if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil {
|
||||||
|
t.Fatalf("failed to preload, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, book, result)
|
||||||
|
}
|
||||||
|
@ -48,10 +48,22 @@ func TestNamedArg(t *testing.T) {
|
|||||||
t.Errorf("failed to update with named arg")
|
t.Errorf("failed to update with named arg")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namedUser.Name1 = "jinzhu-new"
|
||||||
|
namedUser.Name2 = "jinzhu-new2"
|
||||||
|
namedUser.Name3 = "jinzhu-new"
|
||||||
|
|
||||||
var result5 NamedUser
|
var result5 NamedUser
|
||||||
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil {
|
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil {
|
||||||
t.Errorf("failed to update with named arg")
|
t.Errorf("failed to update with named arg")
|
||||||
}
|
}
|
||||||
|
|
||||||
AssertEqual(t, result4, namedUser)
|
AssertEqual(t, result5, namedUser)
|
||||||
|
|
||||||
|
var result6 NamedUser
|
||||||
|
if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name
|
||||||
|
AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil {
|
||||||
|
t.Errorf("failed to update with named arg")
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result6, namedUser)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
@ -31,6 +33,20 @@ func TestPreloadWithAssociations(t *testing.T) {
|
|||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID)
|
DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
var user3 = *GetUser("preload_with_associations_new", Config{
|
||||||
|
Account: true,
|
||||||
|
Pets: 2,
|
||||||
|
Toys: 3,
|
||||||
|
Company: true,
|
||||||
|
Manager: true,
|
||||||
|
Team: 4,
|
||||||
|
Languages: 3,
|
||||||
|
Friends: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user3, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNestedPreload(t *testing.T) {
|
func TestNestedPreload(t *testing.T) {
|
||||||
@ -174,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) {
|
|||||||
CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2])
|
CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPreloadEmptyData(t *testing.T) {
|
||||||
|
var user = *GetUser("user_without_associations", Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name)
|
||||||
|
|
||||||
|
if r, err := json.Marshal(&user); err != nil {
|
||||||
|
t.Errorf("failed to marshal users, got error %v", err)
|
||||||
|
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
|
||||||
|
t.Errorf("json marshal is not empty slice, got %v", string(r))
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []User
|
||||||
|
DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name)
|
||||||
|
|
||||||
|
if r, err := json.Marshal(&results); err != nil {
|
||||||
|
t.Errorf("failed to marshal users, got error %v", err)
|
||||||
|
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
|
||||||
|
t.Errorf("json marshal is not empty slice, got %v", string(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
52
tests/prepared_stmt_test.go
Normal file
52
tests/prepared_stmt_test.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreparedStmt(t *testing.T) {
|
||||||
|
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||||
|
|
||||||
|
if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok {
|
||||||
|
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
txCtx := tx.WithContext(ctx)
|
||||||
|
|
||||||
|
user := *GetUser("prepared_stmt", Config{})
|
||||||
|
|
||||||
|
txCtx.Create(&user)
|
||||||
|
|
||||||
|
var result1 User
|
||||||
|
if err := txCtx.Find(&result1, user.ID).Error; err != nil {
|
||||||
|
t.Fatalf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
var result2 User
|
||||||
|
if err := tx.Find(&result2, user.ID).Error; err != nil {
|
||||||
|
t.Fatalf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user2 := *GetUser("prepared_stmt2", Config{})
|
||||||
|
if err := txCtx.Create(&user2).Error; err == nil {
|
||||||
|
t.Fatalf("should failed to create with timeout context")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Create(&user2).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result3 User
|
||||||
|
if err := tx.Find(&result3, user2.ID).Error; err != nil {
|
||||||
|
t.Fatalf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
@ -1,11 +1,13 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -61,6 +63,54 @@ func TestFind(t *testing.T) {
|
|||||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if _, ok := first[dbName].(string); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if _, ok := first[dbName].(uint); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if _, ok := first[dbName].(*time.Time); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||||
|
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FirstMapWithTable", func(t *testing.T) {
|
||||||
|
var first = map[string]interface{}{}
|
||||||
|
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query first: %v", err)
|
||||||
|
} else {
|
||||||
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
resultType := reflect.ValueOf(first[dbName]).Type().Name()
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if !strings.Contains(resultType, "string") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if !strings.Contains(resultType, "int") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||||
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||||
})
|
})
|
||||||
@ -86,13 +136,29 @@ func TestFind(t *testing.T) {
|
|||||||
t.Run("FirstSliceOfMap", func(t *testing.T) {
|
t.Run("FirstSliceOfMap", func(t *testing.T) {
|
||||||
var allMap = []map[string]interface{}{}
|
var allMap = []map[string]interface{}{}
|
||||||
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||||
t.Errorf("errors happened when query first: %v", err)
|
t.Errorf("errors happened when query find: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for idx, user := range users {
|
for idx, user := range users {
|
||||||
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if _, ok := allMap[idx][dbName].(string); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if _, ok := allMap[idx][dbName].(uint); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||||
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||||
})
|
})
|
||||||
@ -101,6 +167,58 @@ func TestFind(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
|
||||||
|
var allMap = []map[string]interface{}{}
|
||||||
|
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query find: %v", err)
|
||||||
|
} else {
|
||||||
|
for idx, user := range users {
|
||||||
|
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||||
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if !strings.Contains(resultType, "string") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if !strings.Contains(resultType, "int") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||||
|
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
var models []User
|
||||||
|
if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 {
|
||||||
|
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models))
|
||||||
|
} else {
|
||||||
|
for idx, user := range users {
|
||||||
|
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||||
|
CheckUser(t, models[idx], user)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var none []User
|
||||||
|
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
|
||||||
|
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryWithAssociation(t *testing.T) {
|
func TestQueryWithAssociation(t *testing.T) {
|
||||||
@ -228,6 +346,11 @@ func TestNot(t *testing.T) {
|
|||||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
|
||||||
|
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
result = dryDB.Not("name = ?", "jinzhu").Find(&User{})
|
result = dryDB.Not("name = ?", "jinzhu").Find(&User{})
|
||||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) {
|
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) {
|
||||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||||
@ -309,6 +432,33 @@ func TestPluck(t *testing.T) {
|
|||||||
t.Errorf("Unexpected result on pluck id, got %+v", ids)
|
t.Errorf("Unexpected result on pluck id, got %+v", ids)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var times []time.Time
|
||||||
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil {
|
||||||
|
t.Errorf("got error when pluck time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, tv := range times {
|
||||||
|
AssertEqual(t, tv, users[idx].CreatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ptrtimes []*time.Time
|
||||||
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil {
|
||||||
|
t.Errorf("got error when pluck time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, tv := range ptrtimes {
|
||||||
|
AssertEqual(t, tv, users[idx].CreatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
var nulltimes []sql.NullTime
|
||||||
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil {
|
||||||
|
t.Errorf("got error when pluck time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, tv := range nulltimes {
|
||||||
|
AssertEqual(t, tv.Time, users[idx].CreatedAt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelect(t *testing.T) {
|
func TestSelect(t *testing.T) {
|
||||||
@ -508,6 +658,7 @@ func TestLimit(t *testing.T) {
|
|||||||
{Name: "LimitUser3", Age: 20},
|
{Name: "LimitUser3", Age: 20},
|
||||||
{Name: "LimitUser4", Age: 10},
|
{Name: "LimitUser4", Age: 10},
|
||||||
{Name: "LimitUser5", Age: 20},
|
{Name: "LimitUser5", Age: 20},
|
||||||
|
{Name: "LimitUser6", Age: 20},
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Create(&users)
|
DB.Create(&users)
|
||||||
@ -516,7 +667,7 @@ func TestLimit(t *testing.T) {
|
|||||||
DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
|
DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
|
||||||
|
|
||||||
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
|
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
|
||||||
t.Errorf("Limit should works")
|
t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -531,6 +682,7 @@ func TestOffset(t *testing.T) {
|
|||||||
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||||
t.Errorf("Offset should work")
|
t.Errorf("Offset should work")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
||||||
|
|
||||||
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||||
@ -698,3 +850,11 @@ func TestScanNullValue(t *testing.T) {
|
|||||||
t.Fatalf("failed to query slice data with null age, got error %v", err)
|
t.Fatalf("failed to query slice data with null age, got error %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryWithTableAndConditions(t *testing.T) {
|
||||||
|
result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"})
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -16,14 +17,25 @@ func TestScan(t *testing.T) {
|
|||||||
DB.Save(&user1).Save(&user2).Save(&user3)
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
type result struct {
|
type result struct {
|
||||||
|
ID uint
|
||||||
Name string
|
Name string
|
||||||
Age int
|
Age int
|
||||||
}
|
}
|
||||||
|
|
||||||
var res result
|
var res result
|
||||||
DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res)
|
DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res)
|
||||||
if res.Name != user3.Name || res.Age != int(user3.Age) {
|
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
|
||||||
t.Errorf("Scan into struct should work")
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res)
|
||||||
|
if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) {
|
||||||
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res)
|
||||||
|
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
|
||||||
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
|
||||||
}
|
}
|
||||||
|
|
||||||
var doubleAgeRes = &result{}
|
var doubleAgeRes = &result{}
|
||||||
@ -39,11 +51,11 @@ func TestScan(t *testing.T) {
|
|||||||
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results)
|
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results)
|
||||||
|
|
||||||
sort.Slice(results, func(i, j int) bool {
|
sort.Slice(results, func(i, j int) bool {
|
||||||
return strings.Compare(results[i].Name, results[j].Name) < -1
|
return strings.Compare(results[i].Name, results[j].Name) <= -1
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
|
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
|
||||||
t.Errorf("Scan into struct map")
|
t.Errorf("Scan into struct map, got %#v", results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -72,7 +84,21 @@ func TestScanRows(t *testing.T) {
|
|||||||
results = append(results, result)
|
results = append(results, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sort.Slice(results, func(i, j int) bool {
|
||||||
|
return strings.Compare(results[i].Name, results[j].Name) <= -1
|
||||||
|
})
|
||||||
|
|
||||||
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||||
t.Errorf("Should find expected results")
|
t.Errorf("Should find expected results")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ages int
|
||||||
|
if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 {
|
||||||
|
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages)
|
||||||
|
}
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
|
||||||
|
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) {
|
|||||||
t.Errorf("expects: %v, got %v", expects, result)
|
t.Errorf("expects: %v, got %v", expects, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCombineStringConditions(t *testing.T) {
|
||||||
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||||
|
if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) {
|
||||||
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -13,7 +13,7 @@ if [ -d tests ]
|
|||||||
then
|
then
|
||||||
cd tests
|
cd tests
|
||||||
cp go.mod go.mod.bak
|
cp go.mod go.mod.bak
|
||||||
sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod
|
sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod
|
||||||
cd ..
|
cd ..
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ var DB *gorm.DB
|
|||||||
func init() {
|
func init() {
|
||||||
var err error
|
var err error
|
||||||
if DB, err = OpenTestConnection(); err != nil {
|
if DB, err = OpenTestConnection(); err != nil {
|
||||||
log.Printf("failed to connect database, got error %v\n", err)
|
log.Printf("failed to connect database, got error %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
} else {
|
} else {
|
||||||
sqlDB, err := DB.DB()
|
sqlDB, err := DB.DB()
|
||||||
@ -31,7 +31,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to connect database, got error %v\n", err)
|
log.Printf("failed to connect database, got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
RunMigrations()
|
RunMigrations()
|
||||||
|
@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) {
|
|||||||
t.Fatalf("Should find saved record")
|
t.Fatalf("Should find saved record")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransactionOnClosedConn(t *testing.T) {
|
||||||
|
DB, err := OpenTestConnection()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect database, got error %v", err)
|
||||||
|
}
|
||||||
|
rawDB, _ := DB.DB()
|
||||||
|
rawDB.Close()
|
||||||
|
|
||||||
|
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
return nil
|
||||||
|
}); err == nil {
|
||||||
|
t.Errorf("should returns error when commit with closed conn, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error {
|
||||||
|
return nil
|
||||||
|
}); err == nil {
|
||||||
|
t.Errorf("should returns error when commit with closed conn, got error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@ package tests_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) {
|
|||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
user.Company.Name += "new"
|
||||||
|
user.Manager.Name += "new"
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package tests_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
|||||||
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
for _, pet := range user.Pets {
|
||||||
|
pet.Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Pets").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Pets").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
|
|
||||||
t.Run("Polymorphic", func(t *testing.T) {
|
t.Run("Polymorphic", func(t *testing.T) {
|
||||||
var user = *GetUser("update-has-many", Config{})
|
var user = *GetUser("update-has-many", Config{})
|
||||||
|
|
||||||
@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
|||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
for idx := range user.Toys {
|
||||||
|
user.Toys[idx].Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Toys").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Toys").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package tests_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) {
|
|||||||
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
user.Account.Number += "new"
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Account").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Account").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
|
|
||||||
t.Run("Polymorphic", func(t *testing.T) {
|
t.Run("Polymorphic", func(t *testing.T) {
|
||||||
var pet = Pet{Name: "create"}
|
var pet = Pet{Name: "create"}
|
||||||
|
|
||||||
@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) {
|
|||||||
var pet2 Pet
|
var pet2 Pet
|
||||||
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
|
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
|
||||||
CheckPet(t, pet2, pet)
|
CheckPet(t, pet2, pet)
|
||||||
|
|
||||||
|
pet.Toy.Name += "new"
|
||||||
|
if err := DB.Save(&pet).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pet3 Pet
|
||||||
|
DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
|
||||||
|
CheckPet(t, pet2, pet3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pet4 Pet
|
||||||
|
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
|
||||||
|
CheckPet(t, pet4, pet)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package tests_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) {
|
|||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
for idx := range user.Friends {
|
||||||
|
user.Friends[idx].Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx := range user.Languages {
|
||||||
|
user.Languages[idx].Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
}
|
}
|
||||||
|
@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) {
|
|||||||
var result2 Language
|
var result2 Language
|
||||||
DB.First(&result2, "code = ?", "save")
|
DB.First(&result2, "code = ?", "save")
|
||||||
AssertEqual(t, result2, lang)
|
AssertEqual(t, result2, lang)
|
||||||
|
|
||||||
|
DB.Table("langs").Migrator().DropTable(&Language{})
|
||||||
|
DB.Table("langs").AutoMigrate(&Language{})
|
||||||
|
|
||||||
|
if err := DB.Table("langs").Save(&lang).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when creating data, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result3 Language
|
||||||
|
if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name {
|
||||||
|
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3)
|
||||||
|
}
|
||||||
|
|
||||||
|
lang.Name += "name2"
|
||||||
|
if err := DB.Table("langs").Save(&lang).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when creating data, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result4 Language
|
||||||
|
if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name {
|
||||||
|
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ func FileWithLineNum() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func IsValidDBNameChar(c rune) bool {
|
func IsValidDBNameChar(c rune) bool {
|
||||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
|
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckTruth(val interface{}) bool {
|
func CheckTruth(val interface{}) bool {
|
||||||
@ -83,3 +83,31 @@ func AssertEqual(src, dst interface{}) bool {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToString(value interface{}) string {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
case int:
|
||||||
|
return strconv.FormatInt(int64(v), 10)
|
||||||
|
case int8:
|
||||||
|
return strconv.FormatInt(int64(v), 10)
|
||||||
|
case int16:
|
||||||
|
return strconv.FormatInt(int64(v), 10)
|
||||||
|
case int32:
|
||||||
|
return strconv.FormatInt(int64(v), 10)
|
||||||
|
case int64:
|
||||||
|
return strconv.FormatInt(v, 10)
|
||||||
|
case uint:
|
||||||
|
return strconv.FormatUint(uint64(v), 10)
|
||||||
|
case uint8:
|
||||||
|
return strconv.FormatUint(uint64(v), 10)
|
||||||
|
case uint16:
|
||||||
|
return strconv.FormatUint(uint64(v), 10)
|
||||||
|
case uint32:
|
||||||
|
return strconv.FormatUint(uint64(v), 10)
|
||||||
|
case uint64:
|
||||||
|
return strconv.FormatUint(v, 10)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user