Merge branch 'master' into patch-column-scan
# Conflicts: # gorm.go
This commit is contained in:
commit
5b0fcc9176
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -3,7 +3,7 @@ on:
|
|||||||
issues:
|
issues:
|
||||||
types: [opened, edited, reopened]
|
types: [opened, edited, reopened]
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, edited, reopened, ready_for_review, synchronize]
|
types: [opened, edited, reopened]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
triage:
|
triage:
|
||||||
|
10
.github/workflows/tests.yml
vendored
10
.github/workflows/tests.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
|||||||
sqlite:
|
sqlite:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: ['1.14', '1.13']
|
go: ['1.15', '1.14', '1.13']
|
||||||
platform: [ubuntu-latest, macos-latest] # can not run in windows OS
|
platform: [ubuntu-latest, macos-latest] # can not run in windows OS
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ jobs:
|
|||||||
sqlite_windows:
|
sqlite_windows:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: ['1.14', '1.13']
|
go: ['1.15', '1.14', '1.13']
|
||||||
platform: [windows-latest]
|
platform: [windows-latest]
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -64,7 +64,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
|
dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
|
||||||
go: ['1.14', '1.13']
|
go: ['1.15', '1.14', '1.13']
|
||||||
platform: [ubuntu-latest]
|
platform: [ubuntu-latest]
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
|
dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
|
||||||
go: ['1.14', '1.13']
|
go: ['1.15', '1.14', '1.13']
|
||||||
platform: [ubuntu-latest] # can not run in macOS and widnowsOS
|
platform: [ubuntu-latest] # can not run in macOS and widnowsOS
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ jobs:
|
|||||||
sqlserver:
|
sqlserver:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: ['1.14', '1.13']
|
go: ['1.15', '1.14', '1.13']
|
||||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
|
@ -19,11 +19,11 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||||||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||||
* Context, Prepared Statment Mode, DryRun Mode
|
* Context, Prepared Statment Mode, DryRun Mode
|
||||||
* Batch Insert, FindInBatches, Find To Map
|
* Batch Insert, FindInBatches, Find To Map
|
||||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
|
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
|
||||||
* Composite Primary Key
|
* Composite Primary Key
|
||||||
* Auto Migrations
|
* Auto Migrations
|
||||||
* Logger
|
* Logger
|
||||||
* Extendable, write Plugins based on GORM callbacks
|
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
|
||||||
* Every feature comes with tests
|
* Every feature comes with tests
|
||||||
* Developer Friendly
|
* Developer Friendly
|
||||||
|
|
||||||
@ -40,4 +40,3 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||||||
© Jinzhu, 2013~time.Now
|
© Jinzhu, 2013~time.Now
|
||||||
|
|
||||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
|
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
|
||||||
|
|
||||||
|
@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association {
|
|||||||
|
|
||||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
var (
|
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
|
||||||
tx = association.DB.Model(out)
|
|
||||||
)
|
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
|
||||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
|
||||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
|
||||||
joinStmt.AddClause(queryClause)
|
|
||||||
}
|
}
|
||||||
joinStmt.Build("WHERE", "LIMIT")
|
|
||||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
|
||||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
|
||||||
ON: clause.Where{Exprs: queryConds},
|
|
||||||
}}})
|
|
||||||
} else {
|
|
||||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
|
||||||
}
|
|
||||||
|
|
||||||
association.Error = tx.Find(out, conds...).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
return association.Error
|
return association.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error {
|
|||||||
association.Error = association.Replace(values...)
|
association.Error = association.Replace(values...)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
association.saveAssociation(false, values...)
|
association.saveAssociation( /*clear*/ false, values...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error {
|
|||||||
func (association *Association) Replace(values ...interface{}) error {
|
func (association *Association) Replace(values ...interface{}) error {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
// save associations
|
// save associations
|
||||||
association.saveAssociation(true, values...)
|
association.saveAssociation( /*clear*/ true, values...)
|
||||||
|
|
||||||
// set old associations's foreign key to null
|
// set old associations's foreign key to null
|
||||||
reflectValue := association.DB.Statement.ReflectValue
|
reflectValue := association.DB.Statement.ReflectValue
|
||||||
@ -112,7 +88,7 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
updateMap[ref.ForeignKey.DBName] = nil
|
updateMap[ref.ForeignKey.DBName] = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
association.DB.UpdateColumns(updateMap)
|
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||||
}
|
}
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
var (
|
var (
|
||||||
@ -170,7 +146,7 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
|
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
|
||||||
tx.Where(clause.IN{Column: column, Values: values})
|
tx.Where(clause.IN{Column: column, Values: values})
|
||||||
} else {
|
} else {
|
||||||
return ErrorPrimaryKeyRequired
|
return ErrPrimaryKeyRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
||||||
@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
var (
|
var (
|
||||||
primaryFields, relPrimaryFields []*schema.Field
|
primaryFields, relPrimaryFields []*schema.Field
|
||||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
||||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error
|
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
|
// clean up deleted values's foreign key
|
||||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
||||||
|
|
||||||
cleanUpDeletedRelations := func(data reflect.Value) {
|
cleanUpDeletedRelations := func(data reflect.Value) {
|
||||||
@ -328,33 +305,8 @@ func (association *Association) Clear() error {
|
|||||||
|
|
||||||
func (association *Association) Count() (count int64) {
|
func (association *Association) Count() (count int64) {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
var (
|
association.Error = association.buildCondition().Count(&count).Error
|
||||||
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
|
||||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
|
||||||
tx = association.DB.Model(modelValue)
|
|
||||||
)
|
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
|
||||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
|
||||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
|
||||||
joinStmt.AddClause(queryClause)
|
|
||||||
}
|
}
|
||||||
joinStmt.Build("WHERE", "LIMIT")
|
|
||||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
|
||||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
|
||||||
ON: clause.Where{Exprs: conds},
|
|
||||||
}}})
|
|
||||||
} else {
|
|
||||||
tx.Clauses(clause.Where{Exprs: conds})
|
|
||||||
}
|
|
||||||
|
|
||||||
association.Error = tx.Count(&count).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if len(values) != reflectValue.Len() {
|
if len(values) != reflectValue.Len() {
|
||||||
|
// clear old data
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
||||||
@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
|
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
// clear old data
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||||
|
|
||||||
@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (association *Association) buildCondition() *DB {
|
||||||
|
var (
|
||||||
|
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||||
|
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||||
|
tx = association.DB.Model(modelValue)
|
||||||
|
)
|
||||||
|
|
||||||
|
if association.Relationship.JoinTable != nil {
|
||||||
|
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||||
|
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||||
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||||
|
joinStmt.AddClause(queryClause)
|
||||||
|
}
|
||||||
|
joinStmt.Build("WHERE")
|
||||||
|
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||||
|
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||||
|
ON: clause.Where{Exprs: queryConds},
|
||||||
|
}}})
|
||||||
|
} else {
|
||||||
|
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
16
callbacks.go
16
callbacks.go
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||||||
names, sorted []string
|
names, sorted []string
|
||||||
sortCallback func(*callback) error
|
sortCallback func(*callback) error
|
||||||
)
|
)
|
||||||
|
sort.Slice(cs, func(i, j int) bool {
|
||||||
|
return cs[j].before == "*" || cs[j].after == "*"
|
||||||
|
})
|
||||||
|
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||||||
|
|
||||||
sortCallback = func(c *callback) error {
|
sortCallback = func(c *callback) error {
|
||||||
if c.before != "" { // if defined before callback
|
if c.before != "" { // if defined before callback
|
||||||
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
if c.before == "*" && len(sorted) > 0 {
|
||||||
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
|
sorted = append([]string{c.name}, sorted...)
|
||||||
|
}
|
||||||
|
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
// if before callback already sorted, append current callback just after it
|
// if before callback already sorted, append current callback just after it
|
||||||
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||||
@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.after != "" { // if defined after callback
|
if c.after != "" { // if defined after callback
|
||||||
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
if c.after == "*" && len(sorted) > 0 {
|
||||||
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
|
sorted = append(sorted, c.name)
|
||||||
|
}
|
||||||
|
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
||||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
// if after callback sorted, append current callback to last
|
// if after callback sorted, append current callback to last
|
||||||
sorted = append(sorted, c.name)
|
sorted = append(sorted, c.name)
|
||||||
|
@ -48,6 +48,8 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
|
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
obj := db.Statement.ReflectValue.Index(i)
|
obj := db.Statement.ReflectValue.Index(i)
|
||||||
|
|
||||||
|
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||||
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
||||||
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
||||||
objs = append(objs, obj)
|
objs = append(objs, obj)
|
||||||
@ -57,6 +59,9 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||||||
elems = reflect.Append(elems, rv.Addr())
|
elems = reflect.Append(elems, rv.Addr())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
@ -112,6 +117,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
obj := db.Statement.ReflectValue.Index(i)
|
obj := db.Statement.ReflectValue.Index(i)
|
||||||
|
|
||||||
|
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||||
if _, zero := rel.Field.ValueOf(obj); !zero {
|
if _, zero := rel.Field.ValueOf(obj); !zero {
|
||||||
rv := rel.Field.ReflectValueOf(obj)
|
rv := rel.Field.ReflectValueOf(obj)
|
||||||
if rv.Kind() != reflect.Ptr {
|
if rv.Kind() != reflect.Ptr {
|
||||||
@ -130,6 +136,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
elems = reflect.Append(elems, rv)
|
elems = reflect.Append(elems, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
assignmentColumns := []string{}
|
assignmentColumns := []string{}
|
||||||
@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
appendToElems(db.Statement.ReflectValue.Index(i))
|
obj := db.Statement.ReflectValue.Index(i)
|
||||||
|
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||||
|
appendToElems(obj)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
appendToElems(db.Statement.ReflectValue)
|
appendToElems(db.Statement.ReflectValue)
|
||||||
@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
appendToElems(db.Statement.ReflectValue.Index(i))
|
obj := db.Statement.ReflectValue.Index(i)
|
||||||
|
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||||
|
appendToElems(obj)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
appendToElems(db.Statement.ReflectValue)
|
appendToElems(db.Statement.ReflectValue)
|
||||||
|
@ -45,6 +45,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
|||||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
|
|
||||||
db.Callback().Row().Register("gorm:raw", RowQuery)
|
db.Callback().Row().Register("gorm:row", RowQuery)
|
||||||
db.Callback().Raw().Register("gorm:raw", RawExec)
|
db.Callback().Raw().Register("gorm:raw", RawExec)
|
||||||
}
|
}
|
||||||
|
@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) {
|
|||||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.BeforeSave {
|
if db.Statement.Schema.BeforeSave {
|
||||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
if i, ok := value.(BeforeSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeSave(tx))
|
db.AddError(i.BeforeSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.BeforeCreate {
|
if db.Statement.Schema.BeforeCreate {
|
||||||
if i, ok := value.(gorm.BeforeCreateInterface); ok {
|
if i, ok := value.(BeforeCreateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeCreate(tx))
|
db.AddError(i.BeforeCreate(tx))
|
||||||
}
|
}
|
||||||
@ -43,15 +43,13 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
if db.Statement.SQL.String() == "" {
|
||||||
db.Statement.SQL.Grow(180)
|
db.Statement.SQL.Grow(180)
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||||
Table: clause.Table{Name: db.Statement.Table},
|
|
||||||
})
|
|
||||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||||
|
|
||||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun && db.Error == nil {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -63,31 +61,35 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if config.LastInsertIDReversed {
|
if config.LastInsertIDReversed {
|
||||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i))
|
rv := db.Statement.ReflectValue.Index(i)
|
||||||
|
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
||||||
if isZero {
|
if isZero {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||||
insertID--
|
insertID--
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
allUpdated := int(db.RowsAffected) == db.Statement.ReflectValue.Len()
|
|
||||||
isZero := true
|
|
||||||
|
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
|
rv := db.Statement.ReflectValue.Index(i)
|
||||||
if !allUpdated {
|
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||||
_, isZero = db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i))
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if isZero {
|
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||||
insertID++
|
insertID++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
if insertID > 0 {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
@ -111,9 +113,7 @@ func CreateWithReturning(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
if db.Statement.SQL.String() == "" {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||||
Table: clause.Table{Name: db.Statement.Table},
|
|
||||||
})
|
|
||||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||||
|
|
||||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||||
@ -136,7 +136,7 @@ func CreateWithReturning(db *gorm.DB) {
|
|||||||
db.Statement.WriteQuoted(field.DBName)
|
db.Statement.WriteQuoted(field.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun && db.Error == nil {
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -149,10 +149,21 @@ func CreateWithReturning(db *gorm.DB) {
|
|||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
BEGIN:
|
BEGIN:
|
||||||
|
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
|
||||||
|
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
fieldValue := field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected)))
|
fieldValue := field.ReflectValueOf(reflectValue)
|
||||||
|
|
||||||
if onConflict.DoNothing && !fieldValue.IsZero() {
|
if onConflict.DoNothing && !fieldValue.IsZero() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
|
|
||||||
|
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
goto BEGIN
|
goto BEGIN
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,7 +189,7 @@ func CreateWithReturning(db *gorm.DB) {
|
|||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if !db.DryRun {
|
} else if !db.DryRun && db.Error == nil {
|
||||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
} else {
|
} else {
|
||||||
@ -192,14 +203,14 @@ func AfterCreate(db *gorm.DB) {
|
|||||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.AfterSave {
|
if db.Statement.Schema.AfterSave {
|
||||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterSave(tx))
|
db.AddError(i.AfterSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.AfterCreate {
|
if db.Statement.Schema.AfterCreate {
|
||||||
if i, ok := value.(gorm.AfterCreateInterface); ok {
|
if i, ok := value.(AfterCreateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterCreate(tx))
|
db.AddError(i.AfterCreate(tx))
|
||||||
}
|
}
|
||||||
@ -214,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
switch value := stmt.Dest.(type) {
|
switch value := stmt.Dest.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
values = ConvertMapToValuesForCreate(stmt, value)
|
values = ConvertMapToValuesForCreate(stmt, value)
|
||||||
|
case *map[string]interface{}:
|
||||||
|
values = ConvertMapToValuesForCreate(stmt, *value)
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||||
|
case *[]map[string]interface{}:
|
||||||
|
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
|
||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||||
@ -298,6 +313,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
stmt.AddError(gorm.ErrInvalidData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
func BeforeDelete(db *gorm.DB) {
|
func BeforeDelete(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
|
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||||
db.AddError(i.BeforeDelete(tx))
|
db.AddError(i.BeforeDelete(tx))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -41,7 +41,7 @@ func Delete(db *gorm.DB) {
|
|||||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||||
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
@ -51,16 +51,16 @@ func Delete(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
|
||||||
db.AddError(gorm.ErrMissingWhereClause)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||||
db.Statement.Build("DELETE", "FROM", "WHERE")
|
db.Statement.Build("DELETE", "FROM", "WHERE")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun {
|
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||||
|
db.AddError(gorm.ErrMissingWhereClause)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !db.DryRun && db.Error == nil {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -75,7 +75,7 @@ func Delete(db *gorm.DB) {
|
|||||||
func AfterDelete(db *gorm.DB) {
|
func AfterDelete(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(gorm.AfterDeleteInterface); ok {
|
if i, ok := value.(AfterDeleteInterface); ok {
|
||||||
db.AddError(i.AfterDelete(tx))
|
db.AddError(i.AfterDelete(tx))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -20,12 +20,18 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
|
|||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
value := mapValue[k]
|
value := mapValue[k]
|
||||||
|
if stmt.Schema != nil {
|
||||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||||
k = field.DBName
|
k = field.DBName
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
values.Columns = append(values.Columns, clause.Column{Name: k})
|
values.Columns = append(values.Columns, clause.Column{Name: k})
|
||||||
|
if len(values.Values) == 0 {
|
||||||
|
values.Values = [][]interface{}{{}}
|
||||||
|
}
|
||||||
|
|
||||||
values.Values[0] = append(values.Values[0], value)
|
values.Values[0] = append(values.Values[0], value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -42,9 +48,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
|||||||
|
|
||||||
for idx, mapValue := range mapValues {
|
for idx, mapValue := range mapValues {
|
||||||
for k, v := range mapValue {
|
for k, v := range mapValue {
|
||||||
|
if stmt.Schema != nil {
|
||||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||||
k = field.DBName
|
k = field.DBName
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := result[k]; !ok {
|
if _, ok := result[k]; !ok {
|
||||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
@ -61,11 +69,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
|||||||
|
|
||||||
sort.Strings(columns)
|
sort.Strings(columns)
|
||||||
values.Values = make([][]interface{}, len(mapValues))
|
values.Values = make([][]interface{}, len(mapValues))
|
||||||
|
values.Columns = make([]clause.Column, len(columns))
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
|
values.Columns[idx] = clause.Column{Name: column}
|
||||||
|
|
||||||
for i, v := range result[column] {
|
for i, v := range result[column] {
|
||||||
if i == 0 {
|
if len(values.Values[i]) == 0 {
|
||||||
values.Values[i] = make([]interface{}, len(columns))
|
values.Values[i] = make([]interface{}, len(columns))
|
||||||
}
|
}
|
||||||
|
|
||||||
values.Values[i][idx] = v
|
values.Values[i][idx] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
39
callbacks/interfaces.go
Normal file
39
callbacks/interfaces.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package callbacks
|
||||||
|
|
||||||
|
import "gorm.io/gorm"
|
||||||
|
|
||||||
|
type BeforeCreateInterface interface {
|
||||||
|
BeforeCreate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterCreateInterface interface {
|
||||||
|
AfterCreate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type BeforeUpdateInterface interface {
|
||||||
|
BeforeUpdate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterUpdateInterface interface {
|
||||||
|
AfterUpdate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type BeforeSaveInterface interface {
|
||||||
|
BeforeSave(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterSaveInterface interface {
|
||||||
|
AfterSave(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type BeforeDeleteInterface interface {
|
||||||
|
BeforeDelete(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterDeleteInterface interface {
|
||||||
|
AfterDelete(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterFindInterface interface {
|
||||||
|
AfterFind(*gorm.DB) error
|
||||||
|
}
|
@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
|||||||
|
|
||||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||||
column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
|
column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
|
||||||
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
|
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
|
||||||
|
|
||||||
// convert join identity map to relation identity map
|
// convert join identity map to relation identity map
|
||||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||||
@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues)
|
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||||
|
|
||||||
for _, cond := range conds {
|
for _, cond := range conds {
|
||||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||||
@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...)
|
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
|
||||||
|
|
||||||
fieldValues := make([]interface{}, len(relForeignFields))
|
fieldValues := make([]interface{}, len(relForeignFields))
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ func Query(db *gorm.DB) {
|
|||||||
BuildQuerySQL(db)
|
BuildQuerySQL(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun && db.Error == nil {
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
@ -96,7 +96,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
|
|
||||||
// inline joins
|
// inline joins
|
||||||
if len(db.Statement.Joins) != 0 {
|
if len(db.Statement.Joins) != 0 {
|
||||||
if len(db.Statement.Selects) == 0 {
|
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
|
||||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||||
@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
joins := []clause.Join{}
|
joins := []clause.Join{}
|
||||||
for name, conds := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if db.Statement.Schema == nil {
|
if db.Statement.Schema == nil {
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Expression: clause.Expr{SQL: name, Vars: conds},
|
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok {
|
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||||
tableAliasName := relation.Name
|
tableAliasName := relation.Name
|
||||||
|
|
||||||
for _, s := range relation.FieldSchema.DBNames {
|
for _, s := range relation.FieldSchema.DBNames {
|
||||||
@ -124,13 +124,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
for idx, ref := range relation.References {
|
for idx, ref := range relation.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
exprs[idx] = clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName},
|
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if ref.PrimaryValue == "" {
|
if ref.PrimaryValue == "" {
|
||||||
exprs[idx] = clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName},
|
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Expression: clause.Expr{SQL: name, Vars: conds},
|
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -214,7 +214,7 @@ func Preload(db *gorm.DB) {
|
|||||||
func AfterQuery(db *gorm.DB) {
|
func AfterQuery(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(gorm.AfterFindInterface); ok {
|
if i, ok := value.(AfterFindInterface); ok {
|
||||||
db.AddError(i.AfterFind(tx))
|
db.AddError(i.AfterFind(tx))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) {
|
|||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.BeforeSave {
|
if db.Statement.Schema.BeforeSave {
|
||||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
if i, ok := value.(BeforeSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeSave(tx))
|
db.AddError(i.BeforeSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.BeforeUpdate {
|
if db.Statement.Schema.BeforeUpdate {
|
||||||
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
|
if i, ok := value.(BeforeUpdateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeUpdate(tx))
|
db.AddError(i.BeforeUpdate(tx))
|
||||||
}
|
}
|
||||||
@ -69,12 +69,12 @@ func Update(db *gorm.DB) {
|
|||||||
db.Statement.Build("UPDATE", "SET", "WHERE")
|
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||||
db.AddError(gorm.ErrMissingWhereClause)
|
db.AddError(gorm.ErrMissingWhereClause)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun && db.Error == nil {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) {
|
|||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.AfterSave {
|
if db.Statement.Schema.AfterSave {
|
||||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterSave(tx))
|
db.AddError(i.AfterSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.AfterUpdate {
|
if db.Statement.Schema.AfterUpdate {
|
||||||
if i, ok := value.(gorm.AfterUpdateInterface); ok {
|
if i, ok := value.(AfterUpdateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterUpdate(tx))
|
db.AddError(i.AfterUpdate(tx))
|
||||||
}
|
}
|
||||||
@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||||
switch stmt.ReflectValue.Kind() {
|
switch stmt.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var priamryKeyExprs []clause.Expression
|
var primaryKeyExprs []clause.Expression
|
||||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||||
var notZero bool
|
var notZero bool
|
||||||
@ -150,10 +150,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
notZero = notZero || !isZero
|
notZero = notZero || !isZero
|
||||||
}
|
}
|
||||||
if notZero {
|
if notZero {
|
||||||
priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...))
|
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for _, field := range stmt.Schema.PrimaryFields {
|
for _, field := range stmt.Schema.PrimaryFields {
|
||||||
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||||
@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
|
kv := value[k]
|
||||||
|
if _, ok := kv.(*gorm.DB); ok {
|
||||||
|
kv = []interface{}{kv}
|
||||||
|
}
|
||||||
|
|
||||||
if stmt.Schema != nil {
|
if stmt.Schema != nil {
|
||||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
|
||||||
assignValue(field, value[k])
|
assignValue(field, value[k])
|
||||||
}
|
}
|
||||||
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
||||||
@ -189,12 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stmt.UpdatingColumn && stmt.Schema != nil {
|
if !stmt.UpdatingColumn && stmt.Schema != nil {
|
||||||
for _, field := range stmt.Schema.FieldsByDBName {
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
|
field := stmt.Schema.LookUpField(dbName)
|
||||||
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
now := stmt.DB.NowFunc()
|
now := stmt.DB.NowFunc()
|
||||||
@ -202,6 +208,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
|
|
||||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||||
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||||
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||||
} else if field.GORMDataType == schema.Time {
|
} else if field.GORMDataType == schema.Time {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||||
} else {
|
} else {
|
||||||
@ -215,7 +223,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
switch updatingValue.Kind() {
|
switch updatingValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||||
for _, field := range stmt.Schema.FieldsByDBName {
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
|
field := stmt.Schema.LookUpField(dbName)
|
||||||
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
|
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
value, isZero := field.ValueOf(updatingValue)
|
value, isZero := field.ValueOf(updatingValue)
|
||||||
@ -223,6 +232,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
if field.AutoUpdateTime > 0 {
|
if field.AutoUpdateTime > 0 {
|
||||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||||
value = stmt.DB.NowFunc().UnixNano()
|
value = stmt.DB.NowFunc().UnixNano()
|
||||||
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||||
|
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||||
} else if field.GORMDataType == schema.Time {
|
} else if field.GORMDataType == schema.Time {
|
||||||
value = stmt.DB.NowFunc()
|
value = stmt.DB.NowFunc()
|
||||||
} else {
|
} else {
|
||||||
@ -243,6 +254,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
stmt.AddError(gorm.ErrInvalidData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
|||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
tx = tx.Select(args[0], args[1:]...)
|
tx = tx.Select(args[0], args[1:]...)
|
||||||
}
|
}
|
||||||
return tx
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select specify fields that you want when querying, creating, updating
|
// Select specify fields that you want when querying, creating, updating
|
||||||
@ -91,8 +91,9 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
delete(tx.Statement.Clauses, "SELECT")
|
||||||
case string:
|
case string:
|
||||||
fields := strings.FieldsFunc(v, utils.IsChar)
|
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar)
|
||||||
|
|
||||||
// normal field names
|
// normal field names
|
||||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||||
@ -112,6 +113,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
delete(tx.Statement.Clauses, "SELECT")
|
||||||
} else {
|
} else {
|
||||||
tx.Statement.AddClause(clause.Select{
|
tx.Statement.AddClause(clause.Select{
|
||||||
Distinct: db.Statement.Distinct,
|
Distinct: db.Statement.Distinct,
|
||||||
@ -130,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
|||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
|
||||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
|
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
|
||||||
} else {
|
} else {
|
||||||
tx.Statement.Omits = columns
|
tx.Statement.Omits = columns
|
||||||
}
|
}
|
||||||
@ -169,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
|||||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
if tx.Statement.Joins == nil {
|
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||||
tx.Statement.Joins = map[string][]interface{}{}
|
|
||||||
}
|
|
||||||
tx.Statement.Joins[query] = args
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
|||||||
func (db *DB) Group(name string) (tx *DB) {
|
func (db *DB) Group(name string) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
|
||||||
fields := strings.FieldsFunc(name, utils.IsChar)
|
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
|
||||||
tx.Statement.AddClause(clause.GroupBy{
|
tx.Statement.AddClause(clause.GroupBy{
|
||||||
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
|
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
|
||||||
})
|
})
|
||||||
|
16
errors.go
16
errors.go
@ -7,24 +7,20 @@ import (
|
|||||||
var (
|
var (
|
||||||
// ErrRecordNotFound record not found error
|
// ErrRecordNotFound record not found error
|
||||||
ErrRecordNotFound = errors.New("record not found")
|
ErrRecordNotFound = errors.New("record not found")
|
||||||
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
|
||||||
ErrInvalidSQL = errors.New("invalid SQL")
|
|
||||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||||
// ErrUnaddressable unaddressable value
|
|
||||||
ErrUnaddressable = errors.New("using unaddressable value")
|
|
||||||
// ErrNotImplemented not implemented
|
// ErrNotImplemented not implemented
|
||||||
ErrNotImplemented = errors.New("not implemented")
|
ErrNotImplemented = errors.New("not implemented")
|
||||||
// ErrMissingWhereClause missing where clause
|
// ErrMissingWhereClause missing where clause
|
||||||
ErrMissingWhereClause = errors.New("WHERE conditions required")
|
ErrMissingWhereClause = errors.New("WHERE conditions required")
|
||||||
// ErrUnsupportedRelation unsupported relations
|
// ErrUnsupportedRelation unsupported relations
|
||||||
ErrUnsupportedRelation = errors.New("unsupported relations")
|
ErrUnsupportedRelation = errors.New("unsupported relations")
|
||||||
// ErrPtrStructSupported only ptr of struct supported
|
// ErrPrimaryKeyRequired primary keys required
|
||||||
ErrPtrStructSupported = errors.New("only ptr of struct supported")
|
ErrPrimaryKeyRequired = errors.New("primary key required")
|
||||||
// ErrorPrimaryKeyRequired primary keys required
|
// ErrModelValueRequired model value required
|
||||||
ErrorPrimaryKeyRequired = errors.New("primary key required")
|
ErrModelValueRequired = errors.New("model value required")
|
||||||
// ErrorModelValueRequired model value required
|
// ErrInvalidData unsupported data
|
||||||
ErrorModelValueRequired = errors.New("model value required")
|
ErrInvalidData = errors.New("unsupported data")
|
||||||
// ErrUnsupportedDriver unsupported driver
|
// ErrUnsupportedDriver unsupported driver
|
||||||
ErrUnsupportedDriver = errors.New("unsupported driver")
|
ErrUnsupportedDriver = errors.New("unsupported driver")
|
||||||
// ErrRegistered registered
|
// ErrRegistered registered
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -132,8 +133,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||||
for _, expr := range exprs {
|
for _, value := range values {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []clause.Expression:
|
||||||
|
for _, expr := range v {
|
||||||
if eq, ok := expr.(clause.Eq); ok {
|
if eq, ok := expr.(clause.Eq); ok {
|
||||||
switch column := eq.Column.(type) {
|
switch column := eq.Column.(type) {
|
||||||
case string:
|
case string:
|
||||||
@ -144,7 +148,31 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
|||||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||||
|
exprs := tx.Statement.BuildCondition(value)
|
||||||
|
tx.assignInterfacesToValue(exprs)
|
||||||
default:
|
default:
|
||||||
|
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
for _, f := range s.Fields {
|
||||||
|
if f.Readable {
|
||||||
|
if v, isZero := f.ValueOf(reflectValue); !isZero {
|
||||||
|
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
|
||||||
|
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if len(values) > 0 {
|
||||||
|
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
|
||||||
|
tx.assignInterfacesToValue(exprs)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,22 +182,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
|
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
|
||||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||||
if where, ok := c.Expression.(clause.Where); ok {
|
if where, ok := c.Expression.(clause.Where); ok {
|
||||||
tx.assignExprsToValue(where.Exprs)
|
tx.assignInterfacesToValue(where.Exprs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.attrs) > 0 {
|
if len(tx.Statement.attrs) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
tx.Error = nil
|
tx.Error = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.assigns) > 0 {
|
if len(tx.Statement.assigns) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -180,20 +206,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
|
|
||||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||||
if where, ok := c.Expression.(clause.Where); ok {
|
if where, ok := c.Expression.(clause.Where); ok {
|
||||||
tx.assignExprsToValue(where.Exprs)
|
tx.assignInterfacesToValue(where.Exprs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.attrs) > 0 {
|
if len(tx.Statement.attrs) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.assigns) > 0 {
|
if len(tx.Statement.assigns) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Create(dest)
|
return tx.Create(dest)
|
||||||
@ -265,20 +289,30 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
|||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
if tx.Statement.Model == nil {
|
if tx.Statement.Model == nil {
|
||||||
tx.Statement.Model = tx.Statement.Dest
|
tx.Statement.Model = tx.Statement.Dest
|
||||||
|
defer func() {
|
||||||
|
tx.Statement.Model = nil
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tx.Statement.Selects) == 0 {
|
if len(tx.Statement.Selects) == 0 {
|
||||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
||||||
defer tx.Statement.AddClause(clause.Select{})
|
defer delete(tx.Statement.Clauses, "SELECT")
|
||||||
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
|
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
|
||||||
expr := clause.Expr{SQL: "count(1)"}
|
expr := clause.Expr{SQL: "count(1)"}
|
||||||
|
|
||||||
if len(tx.Statement.Selects) == 1 {
|
if len(tx.Statement.Selects) == 1 {
|
||||||
|
dbName := tx.Statement.Selects[0]
|
||||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||||
if f := tx.Statement.Schema.LookUpField(tx.Statement.Selects[0]); f != nil {
|
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
||||||
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: f.DBName}}}
|
dbName = f.DBName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tx.Statement.Distinct {
|
||||||
|
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||||
|
} else {
|
||||||
|
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.Statement.AddClause(clause.Select{Expression: expr})
|
tx.Statement.AddClause(clause.Select{Expression: expr})
|
||||||
@ -325,10 +359,10 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if tx.Statement.Table == "" {
|
} else if tx.Statement.Table == "" {
|
||||||
tx.AddError(ErrorModelValueRequired)
|
tx.AddError(ErrModelValueRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
fields := strings.FieldsFunc(column, utils.IsChar)
|
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||||
Distinct: tx.Statement.Distinct,
|
Distinct: tx.Statement.Distinct,
|
||||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||||
@ -413,7 +447,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
|||||||
|
|
||||||
// Commit commit a transaction
|
// Commit commit a transaction
|
||||||
func (db *DB) Commit() *DB {
|
func (db *DB) Commit() *DB {
|
||||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
||||||
db.AddError(committer.Commit())
|
db.AddError(committer.Commit())
|
||||||
} else {
|
} else {
|
||||||
db.AddError(ErrInvalidTransaction)
|
db.AddError(ErrInvalidTransaction)
|
||||||
@ -424,7 +458,9 @@ func (db *DB) Commit() *DB {
|
|||||||
// Rollback rollback a transaction
|
// Rollback rollback a transaction
|
||||||
func (db *DB) Rollback() *DB {
|
func (db *DB) Rollback() *DB {
|
||||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||||
|
if !reflect.ValueOf(committer).IsNil() {
|
||||||
db.AddError(committer.Rollback())
|
db.AddError(committer.Rollback())
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
db.AddError(ErrInvalidTransaction)
|
db.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
|
21
gorm.go
21
gorm.go
@ -36,6 +36,8 @@ type Config struct {
|
|||||||
Workarounds struct {
|
Workarounds struct {
|
||||||
DoColumnLowerCasing bool
|
DoColumnLowerCasing bool
|
||||||
}
|
}
|
||||||
|
// AllowGlobalUpdate allow global update
|
||||||
|
AllowGlobalUpdate bool
|
||||||
|
|
||||||
// ClauseBuilders clause builder
|
// ClauseBuilders clause builder
|
||||||
ClauseBuilders map[string]clause.ClauseBuilder
|
ClauseBuilders map[string]clause.ClauseBuilder
|
||||||
@ -65,6 +67,7 @@ type Session struct {
|
|||||||
PrepareStmt bool
|
PrepareStmt bool
|
||||||
WithConditions bool
|
WithConditions bool
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
|
AllowGlobalUpdate bool
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
NowFunc func() time.Time
|
NowFunc func() time.Time
|
||||||
@ -112,11 +115,15 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||||||
err = config.Dialector.Initialize(db)
|
err = config.Dialector.Initialize(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PrepareStmt {
|
preparedStmt := &PreparedStmtDB{
|
||||||
db.ConnPool = &PreparedStmtDB{
|
|
||||||
ConnPool: db.ConnPool,
|
ConnPool: db.ConnPool,
|
||||||
Stmts: map[string]*sql.Stmt{},
|
Stmts: map[string]*sql.Stmt{},
|
||||||
|
PreparedSQL: make([]string, 0, 100),
|
||||||
}
|
}
|
||||||
|
db.cacheStore.Store("preparedStmt", preparedStmt)
|
||||||
|
|
||||||
|
if config.PrepareStmt {
|
||||||
|
db.ConnPool = preparedStmt
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement = &Statement{
|
db.Statement = &Statement{
|
||||||
@ -154,6 +161,10 @@ func (db *DB) Session(config *Session) *DB {
|
|||||||
tx.Config.SkipDefaultTransaction = true
|
tx.Config.SkipDefaultTransaction = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.AllowGlobalUpdate {
|
||||||
|
txConfig.AllowGlobalUpdate = true
|
||||||
|
}
|
||||||
|
|
||||||
if config.Context != nil {
|
if config.Context != nil {
|
||||||
tx.Statement = tx.Statement.clone()
|
tx.Statement = tx.Statement.clone()
|
||||||
tx.Statement.DB = tx
|
tx.Statement.DB = tx
|
||||||
@ -161,9 +172,13 @@ func (db *DB) Session(config *Session) *DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.PrepareStmt {
|
if config.PrepareStmt {
|
||||||
|
if v, ok := db.cacheStore.Load("preparedStmt"); ok {
|
||||||
|
preparedStmt := v.(*PreparedStmtDB)
|
||||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||||
ConnPool: db.Config.ConnPool,
|
ConnPool: db.Config.ConnPool,
|
||||||
Stmts: map[string]*sql.Stmt{},
|
Mux: preparedStmt.Mux,
|
||||||
|
Stmts: preparedStmt.Stmts,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,38 +53,7 @@ type TxCommitter interface {
|
|||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type BeforeCreateInterface interface {
|
// Valuer gorm valuer interface
|
||||||
BeforeCreate(*DB) error
|
type Valuer interface {
|
||||||
}
|
GormValue(context.Context, *DB) clause.Expr
|
||||||
|
|
||||||
type AfterCreateInterface interface {
|
|
||||||
AfterCreate(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type BeforeUpdateInterface interface {
|
|
||||||
BeforeUpdate(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterUpdateInterface interface {
|
|
||||||
AfterUpdate(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type BeforeSaveInterface interface {
|
|
||||||
BeforeSave(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterSaveInterface interface {
|
|
||||||
AfterSave(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type BeforeDeleteInterface interface {
|
|
||||||
BeforeDelete(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterDeleteInterface interface {
|
|
||||||
AfterDelete(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterFindInterface interface {
|
|
||||||
AfterFind(*DB) error
|
|
||||||
}
|
}
|
||||||
|
@ -50,14 +50,8 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
case string:
|
case string:
|
||||||
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
|
||||||
default:
|
default:
|
||||||
if v == nil {
|
|
||||||
vars[idx] = "NULL"
|
|
||||||
} else {
|
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
|
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||||
if !rv.IsValid() {
|
|
||||||
vars[idx] = "NULL"
|
|
||||||
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
|
||||||
vars[idx] = "NULL"
|
vars[idx] = "NULL"
|
||||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||||
v, _ = valuer.Value()
|
v, _ = valuer.Value()
|
||||||
@ -71,12 +65,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for idx, v := range vars {
|
for idx, v := range vars {
|
||||||
convertParams(v, idx)
|
convertParams(v, idx)
|
||||||
|
@ -42,6 +42,7 @@ type Migrator interface {
|
|||||||
AddColumn(dst interface{}, field string) error
|
AddColumn(dst interface{}, field string) error
|
||||||
DropColumn(dst interface{}, field string) error
|
DropColumn(dst interface{}, field string) error
|
||||||
AlterColumn(dst interface{}, field string) error
|
AlterColumn(dst interface{}, field string) error
|
||||||
|
MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error
|
||||||
HasColumn(dst interface{}, field string) bool
|
HasColumn(dst interface{}, field string) bool
|
||||||
RenameColumn(dst interface{}, oldName, field string) error
|
RenameColumn(dst interface{}, oldName, field string) error
|
||||||
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
|
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||||||
|
|
||||||
// AutoMigrate
|
// AutoMigrate
|
||||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
// TODO smart migrate data type
|
|
||||||
for _, value := range m.ReorderModels(values, true) {
|
for _, value := range m.ReorderModels(values, true) {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
tx := m.DB.Session(&gorm.Session{})
|
||||||
if !tx.Migrator().HasTable(value) {
|
if !tx.Migrator().HasTable(value) {
|
||||||
@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||||
|
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||||
|
|
||||||
for _, field := range stmt.Schema.FieldsByDBName {
|
for _, field := range stmt.Schema.FieldsByDBName {
|
||||||
if !tx.Migrator().HasColumn(value, field.DBName) {
|
var foundColumn *sql.ColumnType
|
||||||
|
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
if columnType.Name() == field.DBName {
|
||||||
|
foundColumn = columnType
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundColumn == nil {
|
||||||
|
// not found, add column
|
||||||
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||||
|
// found, smart migrate
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -326,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error {
|
||||||
|
// found, smart migrate
|
||||||
|
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
||||||
|
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||||
|
|
||||||
|
alterColumn := false
|
||||||
|
|
||||||
|
// check size
|
||||||
|
if length, _ := columnType.Length(); length != int64(field.Size) {
|
||||||
|
if length > 0 && field.Size > 0 {
|
||||||
|
alterColumn = true
|
||||||
|
} else {
|
||||||
|
// has size in data type and not equal
|
||||||
|
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
|
||||||
|
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
||||||
|
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check precision
|
||||||
|
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||||
|
if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check nullable
|
||||||
|
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||||
|
// not primary key & database is nullable
|
||||||
|
if !field.PrimaryKey && nullable {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if alterColumn {
|
||||||
|
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
||||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
||||||
|
@ -8,40 +8,44 @@ import (
|
|||||||
|
|
||||||
type PreparedStmtDB struct {
|
type PreparedStmtDB struct {
|
||||||
Stmts map[string]*sql.Stmt
|
Stmts map[string]*sql.Stmt
|
||||||
mux sync.RWMutex
|
PreparedSQL []string
|
||||||
|
Mux sync.RWMutex
|
||||||
ConnPool
|
ConnPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) Close() {
|
func (db *PreparedStmtDB) Close() {
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
for k, stmt := range db.Stmts {
|
for _, query := range db.PreparedSQL {
|
||||||
delete(db.Stmts, k)
|
if stmt, ok := db.Stmts[query]; ok {
|
||||||
|
delete(db.Stmts, query)
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||||
db.mux.RLock()
|
db.Mux.RLock()
|
||||||
if stmt, ok := db.Stmts[query]; ok {
|
if stmt, ok := db.Stmts[query]; ok {
|
||||||
db.mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
return stmt, nil
|
return stmt, nil
|
||||||
}
|
}
|
||||||
db.mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
|
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
// double check
|
// double check
|
||||||
if stmt, ok := db.Stmts[query]; ok {
|
if stmt, ok := db.Stmts[query]; ok {
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
return stmt, nil
|
return stmt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
|
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
db.Stmts[query] = stmt
|
db.Stmts[query] = stmt
|
||||||
|
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||||
}
|
}
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
|
|
||||||
return stmt, err
|
return stmt, err
|
||||||
}
|
}
|
||||||
@ -59,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = stmt.ExecContext(ctx, args...)
|
result, err = stmt.ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(db.Stmts, query)
|
delete(db.Stmts, query)
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
@ -73,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = stmt.QueryContext(ctx, args...)
|
rows, err = stmt.QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(db.Stmts, query)
|
delete(db.Stmts, query)
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rows, err
|
return rows, err
|
||||||
@ -100,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(tx.PreparedStmtDB.Stmts, query)
|
delete(tx.PreparedStmtDB.Stmts, query)
|
||||||
tx.PreparedStmtDB.mux.Unlock()
|
tx.PreparedStmtDB.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
@ -114,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(tx.PreparedStmtDB.Stmts, query)
|
delete(tx.PreparedStmtDB.Stmts, query)
|
||||||
tx.PreparedStmtDB.mux.Unlock()
|
tx.PreparedStmtDB.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rows, err
|
return rows, err
|
||||||
|
157
schema/field.go
157
schema/field.go
@ -20,7 +20,8 @@ type TimeType int64
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
UnixSecond TimeType = 1
|
UnixSecond TimeType = 1
|
||||||
UnixNanosecond TimeType = 2
|
UnixMillisecond TimeType = 2
|
||||||
|
UnixNanosecond TimeType = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -54,6 +55,7 @@ type Field struct {
|
|||||||
Comment string
|
Comment string
|
||||||
Size int
|
Size int
|
||||||
Precision int
|
Precision int
|
||||||
|
Scale int
|
||||||
FieldType reflect.Type
|
FieldType reflect.Type
|
||||||
IndirectFieldType reflect.Type
|
IndirectFieldType reflect.Type
|
||||||
StructField reflect.StructField
|
StructField reflect.StructField
|
||||||
@ -61,6 +63,7 @@ type Field struct {
|
|||||||
TagSettings map[string]string
|
TagSettings map[string]string
|
||||||
Schema *Schema
|
Schema *Schema
|
||||||
EmbeddedSchema *Schema
|
EmbeddedSchema *Schema
|
||||||
|
OwnerSchema *Schema
|
||||||
ReflectValueOf func(reflect.Value) reflect.Value
|
ReflectValueOf func(reflect.Value) reflect.Value
|
||||||
ValueOf func(reflect.Value) (value interface{}, zero bool)
|
ValueOf func(reflect.Value) (value interface{}, zero bool)
|
||||||
Set func(reflect.Value, interface{}) error
|
Set func(reflect.Value, interface{}) error
|
||||||
@ -86,46 +89,34 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fieldValue := reflect.New(field.IndirectFieldType)
|
fieldValue := reflect.New(field.IndirectFieldType)
|
||||||
|
|
||||||
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
|
|
||||||
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
|
|
||||||
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
|
|
||||||
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
|
|
||||||
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if field is valuer, used its value or first fields as data type
|
// if field is valuer, used its value or first fields as data type
|
||||||
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
|
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
|
||||||
if isValuer {
|
if isValuer {
|
||||||
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
|
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
|
||||||
var overrideFieldValue bool
|
if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
|
||||||
if v, err := valuer.Value(); v != nil && err == nil {
|
|
||||||
overrideFieldValue = true
|
|
||||||
fieldValue = reflect.ValueOf(v)
|
fieldValue = reflect.ValueOf(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
if field.IndirectFieldType.Kind() == reflect.Struct {
|
var getRealFieldValue func(reflect.Value)
|
||||||
for i := 0; i < field.IndirectFieldType.NumField(); i++ {
|
getRealFieldValue = func(v reflect.Value) {
|
||||||
if !overrideFieldValue {
|
rv := reflect.Indirect(v)
|
||||||
newFieldType := field.IndirectFieldType.Field(i).Type
|
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
||||||
|
for i := 0; i < rv.Type().NumField(); i++ {
|
||||||
|
newFieldType := rv.Type().Field(i).Type
|
||||||
for newFieldType.Kind() == reflect.Ptr {
|
for newFieldType.Kind() == reflect.Ptr {
|
||||||
newFieldType = newFieldType.Elem()
|
newFieldType = newFieldType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldValue = reflect.New(newFieldType)
|
fieldValue = reflect.New(newFieldType)
|
||||||
overrideFieldValue = true
|
|
||||||
|
if rv.Type() != reflect.Indirect(fieldValue).Type() {
|
||||||
|
getRealFieldValue(fieldValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fieldValue.IsValid() {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy tag settings from valuer
|
|
||||||
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
|
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
|
||||||
if _, ok := field.TagSettings[key]; !ok {
|
if _, ok := field.TagSettings[key]; !ok {
|
||||||
field.TagSettings[key] = value
|
field.TagSettings[key] = value
|
||||||
@ -134,6 +125,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getRealFieldValue(fieldValue)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if dbName, ok := field.TagSettings["COLUMN"]; ok {
|
if dbName, ok := field.TagSettings["COLUMN"]; ok {
|
||||||
@ -167,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
field.Precision, _ = strconv.Atoi(p)
|
field.Precision, _ = strconv.Atoi(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s, ok := field.TagSettings["SCALE"]; ok {
|
||||||
|
field.Scale, _ = strconv.Atoi(s)
|
||||||
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
||||||
field.NotNull = true
|
field.NotNull = true
|
||||||
}
|
}
|
||||||
@ -213,6 +211,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
|
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
||||||
|
field.DataType = Time
|
||||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
}
|
}
|
||||||
@ -231,6 +231,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||||
if strings.ToUpper(v) == "NANO" {
|
if strings.ToUpper(v) == "NANO" {
|
||||||
field.AutoCreateTime = UnixNanosecond
|
field.AutoCreateTime = UnixNanosecond
|
||||||
|
} else if strings.ToUpper(v) == "MILLI" {
|
||||||
|
field.AutoCreateTime = UnixMillisecond
|
||||||
} else {
|
} else {
|
||||||
field.AutoCreateTime = UnixSecond
|
field.AutoCreateTime = UnixSecond
|
||||||
}
|
}
|
||||||
@ -239,6 +241,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||||
if strings.ToUpper(v) == "NANO" {
|
if strings.ToUpper(v) == "NANO" {
|
||||||
field.AutoUpdateTime = UnixNanosecond
|
field.AutoUpdateTime = UnixNanosecond
|
||||||
|
} else if strings.ToUpper(v) == "MILLI" {
|
||||||
|
field.AutoUpdateTime = UnixMillisecond
|
||||||
} else {
|
} else {
|
||||||
field.AutoUpdateTime = UnixSecond
|
field.AutoUpdateTime = UnixSecond
|
||||||
}
|
}
|
||||||
@ -303,16 +307,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) {
|
if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) {
|
||||||
|
if reflect.Indirect(fieldValue).Kind() == reflect.Struct {
|
||||||
var err error
|
var err error
|
||||||
field.Creatable = false
|
field.Creatable = false
|
||||||
field.Updatable = false
|
field.Updatable = false
|
||||||
field.Readable = false
|
field.Readable = false
|
||||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
|
||||||
|
cacheStore := &sync.Map{}
|
||||||
|
cacheStore.Store(embeddedCacheKey, true)
|
||||||
|
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ef := range field.EmbeddedSchema.Fields {
|
for _, ef := range field.EmbeddedSchema.Fields {
|
||||||
ef.Schema = schema
|
ef.Schema = schema
|
||||||
|
ef.OwnerSchema = field.EmbeddedSchema
|
||||||
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
||||||
// index is negative means is pointer
|
// index is negative means is pointer
|
||||||
if field.FieldType.Kind() == reflect.Struct {
|
if field.FieldType.Kind() == reflect.Struct {
|
||||||
@ -331,17 +341,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
ef.PrimaryKey = true
|
ef.PrimaryKey = true
|
||||||
} else {
|
} else {
|
||||||
ef.PrimaryKey = false
|
ef.PrimaryKey = false
|
||||||
|
|
||||||
|
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
|
||||||
|
ef.AutoIncrement = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if ef.DefaultValue == "" {
|
||||||
|
ef.HasDefaultValue = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range field.TagSettings {
|
for k, v := range field.TagSettings {
|
||||||
ef.TagSettings[k] = v
|
ef.TagSettings[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...)
|
schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
|
||||||
field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...)
|
}
|
||||||
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...)
|
|
||||||
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return field
|
return field
|
||||||
@ -462,16 +478,16 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
if reflectV.Kind() == reflect.Ptr {
|
||||||
if v, err = valuer.Value(); err == nil {
|
|
||||||
err = setter(value, v)
|
|
||||||
}
|
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
|
||||||
if reflectV.IsNil() {
|
if reflectV.IsNil() {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else {
|
} else {
|
||||||
err = setter(value, reflectV.Elem().Interface())
|
err = setter(value, reflectV.Elem().Interface())
|
||||||
}
|
}
|
||||||
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
if v, err = valuer.Value(); err == nil {
|
||||||
|
err = setter(value, v)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
||||||
}
|
}
|
||||||
@ -545,6 +561,8 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case time.Time:
|
case time.Time:
|
||||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||||
field.ReflectValueOf(value).SetInt(data.UnixNano())
|
field.ReflectValueOf(value).SetInt(data.UnixNano())
|
||||||
|
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||||
|
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
|
||||||
} else {
|
} else {
|
||||||
field.ReflectValueOf(value).SetInt(data.Unix())
|
field.ReflectValueOf(value).SetInt(data.Unix())
|
||||||
}
|
}
|
||||||
@ -552,6 +570,8 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
if data != nil {
|
if data != nil {
|
||||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||||
field.ReflectValueOf(value).SetInt(data.UnixNano())
|
field.ReflectValueOf(value).SetInt(data.UnixNano())
|
||||||
|
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||||
|
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
|
||||||
} else {
|
} else {
|
||||||
field.ReflectValueOf(value).SetInt(data.Unix())
|
field.ReflectValueOf(value).SetInt(data.Unix())
|
||||||
}
|
}
|
||||||
@ -714,38 +734,9 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||||
// struct scanner
|
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
|
||||||
v, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
|
|
||||||
reflectV := reflect.ValueOf(v)
|
|
||||||
if !reflectV.IsValid() {
|
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
|
||||||
if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() {
|
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
|
||||||
} else {
|
|
||||||
return field.Set(value, reflectV.Elem().Interface())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
|
||||||
// pointer scanner
|
// pointer scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
|
||||||
if valuer == nil {
|
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
|
||||||
} else {
|
|
||||||
v, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if reflectV.Type().AssignableTo(field.FieldType) {
|
if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV)
|
field.ReflectValueOf(value).Set(reflectV)
|
||||||
@ -760,10 +751,38 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
v, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||||
|
// struct scanner
|
||||||
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
|
reflectV := reflect.ValueOf(v)
|
||||||
|
if !reflectV.IsValid() {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
|
field.ReflectValueOf(value).Set(reflectV)
|
||||||
|
} else if reflectV.Kind() == reflect.Ptr {
|
||||||
|
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else {
|
||||||
|
return field.Set(value, reflectV.Elem().Interface())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
v, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
|
@ -1,23 +1,25 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import "gorm.io/gorm/clause"
|
import (
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
type GormDataTypeInterface interface {
|
type GormDataTypeInterface interface {
|
||||||
GormDataType() string
|
GormDataType() string
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateClausesInterface interface {
|
type CreateClausesInterface interface {
|
||||||
CreateClauses() []clause.Interface
|
CreateClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryClausesInterface interface {
|
type QueryClausesInterface interface {
|
||||||
QueryClauses() []clause.Interface
|
QueryClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateClausesInterface interface {
|
type UpdateClausesInterface interface {
|
||||||
UpdateClauses() []clause.Interface
|
UpdateClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeleteClausesInterface interface {
|
type DeleteClausesInterface interface {
|
||||||
DeleteClauses() []clause.Interface
|
DeleteClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
@ -39,3 +39,24 @@ type AdvancedDataTypeUser struct {
|
|||||||
Active mybool
|
Active mybool
|
||||||
Admin *mybool
|
Admin *mybool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BaseModel struct {
|
||||||
|
ID uint
|
||||||
|
CreatedAt time.Time
|
||||||
|
CreatedBy *int
|
||||||
|
Created *VersionUser `gorm:"foreignKey:CreatedBy"`
|
||||||
|
UpdatedAt time.Time
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type VersionModel struct {
|
||||||
|
BaseModel
|
||||||
|
Version int
|
||||||
|
}
|
||||||
|
|
||||||
|
type VersionUser struct {
|
||||||
|
VersionModel
|
||||||
|
Name string
|
||||||
|
Age uint
|
||||||
|
Birthday *time.Time
|
||||||
|
}
|
||||||
|
@ -66,7 +66,12 @@ func (schema *Schema) parseRelation(field *Field) {
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
cacheStore := schema.cacheStore
|
||||||
|
if field.OwnerSchema != nil {
|
||||||
|
cacheStore = field.OwnerSchema.cacheStore
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -77,8 +82,10 @@ func (schema *Schema) parseRelation(field *Field) {
|
|||||||
schema.buildMany2ManyRelation(relation, field, many2many)
|
schema.buildMany2ManyRelation(relation, field, many2many)
|
||||||
} else {
|
} else {
|
||||||
switch field.IndirectFieldType.Kind() {
|
switch field.IndirectFieldType.Kind() {
|
||||||
case reflect.Struct, reflect.Slice:
|
case reflect.Struct:
|
||||||
schema.guessRelation(relation, field, true)
|
schema.guessRelation(relation, field, guessBelongs)
|
||||||
|
case reflect.Slice:
|
||||||
|
schema.guessRelation(relation, field, guessHas)
|
||||||
default:
|
default:
|
||||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
||||||
}
|
}
|
||||||
@ -220,7 +227,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
Name: joinFieldName,
|
Name: joinFieldName,
|
||||||
PkgPath: ownField.StructField.PkgPath,
|
PkgPath: ownField.StructField.PkgPath,
|
||||||
Type: ownField.StructField.Type,
|
Type: ownField.StructField.Type,
|
||||||
Tag: removeSettingFromTag(ownField.StructField.Tag, "column"),
|
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,7 +250,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
Name: joinFieldName,
|
Name: joinFieldName,
|
||||||
PkgPath: relField.StructField.PkgPath,
|
PkgPath: relField.StructField.PkgPath,
|
||||||
Type: relField.StructField.Type,
|
Type: relField.StructField.Type,
|
||||||
Tag: removeSettingFromTag(relField.StructField.Tag, "column"),
|
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,21 +323,52 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) {
|
type guessLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
guessBelongs guessLevel = iota
|
||||||
|
guessEmbeddedBelongs
|
||||||
|
guessHas
|
||||||
|
guessEmbeddedHas
|
||||||
|
)
|
||||||
|
|
||||||
|
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
|
||||||
var (
|
var (
|
||||||
primaryFields, foreignFields []*Field
|
primaryFields, foreignFields []*Field
|
||||||
primarySchema, foreignSchema = schema, relation.FieldSchema
|
primarySchema, foreignSchema = schema, relation.FieldSchema
|
||||||
)
|
)
|
||||||
|
|
||||||
if !guessHas {
|
reguessOrErr := func() {
|
||||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
switch gl {
|
||||||
|
case guessBelongs:
|
||||||
|
schema.guessRelation(relation, field, guessEmbeddedBelongs)
|
||||||
|
case guessEmbeddedBelongs:
|
||||||
|
schema.guessRelation(relation, field, guessHas)
|
||||||
|
case guessHas:
|
||||||
|
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||||
|
// case guessEmbeddedHas:
|
||||||
|
default:
|
||||||
|
schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reguessOrErr := func(err string, args ...interface{}) {
|
switch gl {
|
||||||
if guessHas {
|
case guessBelongs:
|
||||||
schema.guessRelation(relation, field, false)
|
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||||
|
case guessEmbeddedBelongs:
|
||||||
|
if field.OwnerSchema != nil {
|
||||||
|
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||||
} else {
|
} else {
|
||||||
schema.err = fmt.Errorf(err, args...)
|
reguessOrErr()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case guessHas:
|
||||||
|
case guessEmbeddedHas:
|
||||||
|
if field.OwnerSchema != nil {
|
||||||
|
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||||
|
} else {
|
||||||
|
reguessOrErr()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -339,14 +377,14 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||||||
if f := foreignSchema.LookUpField(foreignKey); f != nil {
|
if f := foreignSchema.LookUpField(foreignKey); f != nil {
|
||||||
foreignFields = append(foreignFields, f)
|
foreignFields = append(foreignFields, f)
|
||||||
} else {
|
} else {
|
||||||
reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys)
|
reguessOrErr()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, primaryField := range primarySchema.PrimaryFields {
|
for _, primaryField := range primarySchema.PrimaryFields {
|
||||||
lookUpName := schema.Name + primaryField.Name
|
lookUpName := primarySchema.Name + primaryField.Name
|
||||||
if !guessHas {
|
if gl == guessBelongs {
|
||||||
lookUpName = field.Name + primaryField.Name
|
lookUpName = field.Name + primaryField.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -358,7 +396,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(foreignFields) == 0 {
|
if len(foreignFields) == 0 {
|
||||||
reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas)
|
reguessOrErr()
|
||||||
return
|
return
|
||||||
} else if len(relation.primaryKeys) > 0 {
|
} else if len(relation.primaryKeys) > 0 {
|
||||||
for idx, primaryKey := range relation.primaryKeys {
|
for idx, primaryKey := range relation.primaryKeys {
|
||||||
@ -366,11 +404,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||||||
if len(primaryFields) < idx+1 {
|
if len(primaryFields) < idx+1 {
|
||||||
primaryFields = append(primaryFields, f)
|
primaryFields = append(primaryFields, f)
|
||||||
} else if f != primaryFields[idx] {
|
} else if f != primaryFields[idx] {
|
||||||
reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
|
reguessOrErr()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
|
reguessOrErr()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -380,7 +418,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||||||
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
||||||
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
|
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
|
||||||
} else {
|
} else {
|
||||||
reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
reguessOrErr()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -394,11 +432,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||||||
relation.References = append(relation.References, &Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryKey: primaryFields[idx],
|
PrimaryKey: primaryFields[idx],
|
||||||
ForeignKey: foreignField,
|
ForeignKey: foreignField,
|
||||||
OwnPrimaryKey: schema == primarySchema && guessHas,
|
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if guessHas {
|
if gl == guessHas || gl == guessEmbeddedHas {
|
||||||
relation.Type = "has"
|
relation.Type = "has"
|
||||||
} else {
|
} else {
|
||||||
relation.Type = BelongsTo
|
relation.Type = BelongsTo
|
||||||
|
@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
|
||||||
|
type User struct {
|
||||||
|
ID int32 `gorm:"primaryKey"`
|
||||||
|
Name string
|
||||||
|
CreatedBy *int32
|
||||||
|
Creator *User `gorm:"foreignKey:CreatedBy;references:ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
checkStructRelation(t, &User{}, Relation{
|
||||||
|
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
|
||||||
|
References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestHasOneOverrideForeignKey(t *testing.T) {
|
func TestHasOneOverrideForeignKey(t *testing.T) {
|
||||||
type Profile struct {
|
type Profile struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
|
@ -72,6 +72,10 @@ type Tabler interface {
|
|||||||
|
|
||||||
// get data type from dialector
|
// get data type from dialector
|
||||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||||
|
if dest == nil {
|
||||||
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||||
|
}
|
||||||
|
|
||||||
modelType := reflect.ValueOf(dest).Type()
|
modelType := reflect.ValueOf(dest).Type()
|
||||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
@ -184,11 +188,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||||||
if field := schema.PrioritizedPrimaryField; field != nil {
|
if field := schema.PrioritizedPrimaryField; field != nil {
|
||||||
switch field.GORMDataType {
|
switch field.GORMDataType {
|
||||||
case Int, Uint:
|
case Int, Uint:
|
||||||
|
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
|
||||||
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
|
|
||||||
field.HasDefaultValue = true
|
field.HasDefaultValue = true
|
||||||
field.AutoIncrement = true
|
field.AutoIncrement = true
|
||||||
}
|
}
|
||||||
@ -208,13 +212,31 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
|
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
|
||||||
// parse relations for unidentified fields
|
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if field.DataType == "" && field.Creatable {
|
if field.DataType == "" && field.Creatable {
|
||||||
if schema.parseRelation(field); schema.err != nil {
|
if schema.parseRelation(field); schema.err != nil {
|
||||||
return schema, schema.err
|
return schema, schema.err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fieldValue := reflect.New(field.IndirectFieldType)
|
||||||
|
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
|
||||||
|
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
|
||||||
|
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
|
||||||
|
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
|
||||||
|
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,10 +49,15 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
|
parsedField, ok := s.FieldsByDBName[f.DBName]
|
||||||
|
if !ok {
|
||||||
|
parsedField, ok = s.FieldsByName[f.Name]
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
|
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
|
||||||
} else {
|
} else {
|
||||||
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
|
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings")
|
||||||
|
|
||||||
if f.DBName != "" {
|
if f.DBName != "" {
|
||||||
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
|
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
|
||||||
@ -62,7 +67,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
|
|||||||
|
|
||||||
for _, name := range []string{f.DBName, f.Name} {
|
for _, name := range []string{f.DBName, f.Name} {
|
||||||
if name != "" {
|
if name != "" {
|
||||||
if field := s.LookUpField(name); field == nil || parsedField != field {
|
if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) {
|
||||||
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
|
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
"gorm.io/gorm/utils/tests"
|
"gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
@ -160,3 +161,65 @@ func TestCustomizeTableName(t *testing.T) {
|
|||||||
t.Errorf("Failed to customize table with TableName method")
|
t.Errorf("Failed to customize table with TableName method")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNestedModel(t *testing.T) {
|
||||||
|
versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse nested user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := []schema.Field{
|
||||||
|
{Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true},
|
||||||
|
{Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64},
|
||||||
|
{Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
checkSchemaField(t, versionUser, &f, func(f *schema.Field) {
|
||||||
|
f.Creatable = true
|
||||||
|
f.Updatable = true
|
||||||
|
f.Readable = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedStruct(t *testing.T) {
|
||||||
|
type CorpBase struct {
|
||||||
|
gorm.Model
|
||||||
|
OwnerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Company struct {
|
||||||
|
ID int
|
||||||
|
OwnerID int
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Corp struct {
|
||||||
|
CorpBase
|
||||||
|
Base Company `gorm:"embedded;embeddedPrefix:company_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := []schema.Field{
|
||||||
|
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
||||||
|
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||||
|
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
|
||||||
|
f.Creatable = true
|
||||||
|
f.Updatable = true
|
||||||
|
f.Readable = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var embeddedCacheKey = "embedded_cache_store"
|
||||||
|
|
||||||
func ParseTagSetting(str string, sep string) map[string]string {
|
func ParseTagSetting(str string, sep string) map[string]string {
|
||||||
settings := map[string]string{}
|
settings := map[string]string{}
|
||||||
names := strings.Split(str, sep)
|
names := strings.Split(str, sep)
|
||||||
@ -49,8 +51,11 @@ func toColumns(val string) (results []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag {
|
func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
|
||||||
return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}"))
|
for _, name := range names {
|
||||||
|
tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
|
||||||
|
}
|
||||||
|
return tag
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRelationsValues get relations's values from a reflect value
|
// GetRelationsValues get relations's values from a reflect value
|
||||||
|
@ -13,6 +13,7 @@ func TestRemoveSettingFromTag(t *testing.T) {
|
|||||||
`gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`,
|
`gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`,
|
||||||
`gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`,
|
`gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`,
|
||||||
`gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`,
|
`gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`,
|
||||||
|
`gorm:"before:value;column; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`,
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range tags {
|
for k, v := range tags {
|
||||||
|
@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) {
|
|||||||
return n.Time, nil
|
return n.Time, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (DeletedAt) QueryClauses() []clause.Interface {
|
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||||
return []clause.Interface{
|
return []clause.Interface{
|
||||||
clause.Where{Exprs: []clause.Expression{
|
clause.Where{Exprs: []clause.Expression{
|
||||||
clause.Eq{
|
clause.Eq{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"},
|
Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
|
||||||
Value: nil,
|
Value: nil,
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (DeletedAt) DeleteClauses() []clause.Interface {
|
type SoftDeleteQueryClause struct {
|
||||||
return []clause.Interface{SoftDeleteClause{}}
|
Field *schema.Field
|
||||||
}
|
}
|
||||||
|
|
||||||
type SoftDeleteClause struct {
|
func (sd SoftDeleteQueryClause) Name() string {
|
||||||
}
|
|
||||||
|
|
||||||
func (SoftDeleteClause) Name() string {
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (SoftDeleteClause) Build(clause.Builder) {
|
func (sd SoftDeleteQueryClause) Build(clause.Builder) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (SoftDeleteClause) MergeClause(*clause.Clause) {
|
func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||||
|
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||||
|
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||||
|
}})
|
||||||
|
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||||
|
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SoftDeleteDeleteClause struct {
|
||||||
|
Field *schema.Field
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd SoftDeleteDeleteClause) Name() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||||
if stmt.SQL.String() == "" {
|
if stmt.SQL.String() == "" {
|
||||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}})
|
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}})
|
||||||
|
|
||||||
if stmt.Schema != nil {
|
if stmt.Schema != nil {
|
||||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||||
@ -64,7 +88,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
|||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
}
|
}
|
||||||
|
|
||||||
if stmt.Dest != stmt.Model && stmt.Model != nil {
|
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||||
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
@ -74,11 +98,6 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := stmt.Clauses["WHERE"]; !ok {
|
|
||||||
stmt.DB.AddError(ErrMissingWhereClause)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt.AddClauseIfNotExists(clause.Update{})
|
stmt.AddClauseIfNotExists(clause.Update{})
|
||||||
stmt.Build("UPDATE", "SET", "WHERE")
|
stmt.Build("UPDATE", "SET", "WHERE")
|
||||||
}
|
}
|
||||||
|
33
statement.go
33
statement.go
@ -29,7 +29,7 @@ type Statement struct {
|
|||||||
Distinct bool
|
Distinct bool
|
||||||
Selects []string // selected columns
|
Selects []string // selected columns
|
||||||
Omits []string // omit columns
|
Omits []string // omit columns
|
||||||
Joins map[string][]interface{}
|
Joins []join
|
||||||
Preloads map[string][]interface{}
|
Preloads map[string][]interface{}
|
||||||
Settings sync.Map
|
Settings sync.Map
|
||||||
ConnPool ConnPool
|
ConnPool ConnPool
|
||||||
@ -44,6 +44,11 @@ type Statement struct {
|
|||||||
assigns []interface{}
|
assigns []interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type join struct {
|
||||||
|
Name string
|
||||||
|
Conds []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
// StatementModifier statement modifier interface
|
// StatementModifier statement modifier interface
|
||||||
type StatementModifier interface {
|
type StatementModifier interface {
|
||||||
ModifyStatement(*Statement)
|
ModifyStatement(*Statement)
|
||||||
@ -95,7 +100,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if v.Name == clause.PrimaryKey {
|
if v.Name == clause.PrimaryKey {
|
||||||
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
|
if stmt.Schema == nil {
|
||||||
|
stmt.DB.AddError(ErrModelValueRequired)
|
||||||
|
} else if stmt.Schema.PrioritizedPrimaryField != nil {
|
||||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
|
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||||
} else if len(stmt.Schema.DBNames) > 0 {
|
} else if len(stmt.Schema.DBNames) > 0 {
|
||||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0])
|
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0])
|
||||||
@ -154,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||||||
stmt.Vars = append(stmt.Vars, v.Value)
|
stmt.Vars = append(stmt.Vars, v.Value)
|
||||||
case clause.Column, clause.Table:
|
case clause.Column, clause.Table:
|
||||||
stmt.QuoteTo(writer, v)
|
stmt.QuoteTo(writer, v)
|
||||||
|
case Valuer:
|
||||||
|
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||||
case clause.Expr:
|
case clause.Expr:
|
||||||
var varStr strings.Builder
|
var varStr strings.Builder
|
||||||
var sql = v.SQL
|
var sql = v.SQL
|
||||||
@ -307,10 +316,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
|||||||
for _, field := range s.Fields {
|
for _, field := range s.Fields {
|
||||||
if field.Readable {
|
if field.Readable {
|
||||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||||
if field.DBName == "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
|
||||||
} else {
|
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
||||||
|
} else if field.DataType != "" {
|
||||||
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -320,10 +329,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
|||||||
for _, field := range s.Fields {
|
for _, field := range s.Fields {
|
||||||
if field.Readable {
|
if field.Readable {
|
||||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
||||||
if field.DBName == "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
|
||||||
} else {
|
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
||||||
|
} else if field.DataType != "" {
|
||||||
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -372,7 +381,6 @@ func (stmt *Statement) Build(clauses ...string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO handle named vars
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||||
@ -390,6 +398,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
|||||||
|
|
||||||
func (stmt *Statement) clone() *Statement {
|
func (stmt *Statement) clone() *Statement {
|
||||||
newStmt := &Statement{
|
newStmt := &Statement{
|
||||||
|
TableExpr: stmt.TableExpr,
|
||||||
Table: stmt.Table,
|
Table: stmt.Table,
|
||||||
Model: stmt.Model,
|
Model: stmt.Model,
|
||||||
Dest: stmt.Dest,
|
Dest: stmt.Dest,
|
||||||
@ -398,7 +407,6 @@ func (stmt *Statement) clone() *Statement {
|
|||||||
Distinct: stmt.Distinct,
|
Distinct: stmt.Distinct,
|
||||||
Selects: stmt.Selects,
|
Selects: stmt.Selects,
|
||||||
Omits: stmt.Omits,
|
Omits: stmt.Omits,
|
||||||
Joins: map[string][]interface{}{},
|
|
||||||
Preloads: map[string][]interface{}{},
|
Preloads: map[string][]interface{}{},
|
||||||
ConnPool: stmt.ConnPool,
|
ConnPool: stmt.ConnPool,
|
||||||
Schema: stmt.Schema,
|
Schema: stmt.Schema,
|
||||||
@ -414,8 +422,9 @@ func (stmt *Statement) clone() *Statement {
|
|||||||
newStmt.Preloads[k] = p
|
newStmt.Preloads[k] = p
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, j := range stmt.Joins {
|
if len(stmt.Joins) > 0 {
|
||||||
newStmt.Joins[k] = j
|
newStmt.Joins = make([]join, len(stmt.Joins))
|
||||||
|
copy(newStmt.Joins, stmt.Joins)
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt.Settings.Range(func(k, v interface{}) bool {
|
stmt.Settings.Range(func(k, v interface{}) bool {
|
||||||
|
1
tests/.gitignore
vendored
Normal file
1
tests/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
go.sum
|
@ -21,6 +21,23 @@ func TestHasManyAssociation(t *testing.T) {
|
|||||||
DB.Model(&user2).Association("Pets").Find(&user2.Pets)
|
DB.Model(&user2).Association("Pets").Find(&user2.Pets)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
var pets []Pet
|
||||||
|
DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets)
|
||||||
|
|
||||||
|
if len(pets) != 1 {
|
||||||
|
t.Fatalf("should only find one pets, but got %v", len(pets))
|
||||||
|
}
|
||||||
|
|
||||||
|
CheckPet(t, pets[0], *user.Pets[0])
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 {
|
||||||
|
t.Fatalf("should only find one pets, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 {
|
||||||
|
t.Fatalf("should only find no pet with invalid conditions, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
// Count
|
// Count
|
||||||
AssertAssociationCount(t, user, "Pets", 2, "")
|
AssertAssociationCount(t, user, "Pets", 2, "")
|
||||||
|
|
||||||
@ -40,13 +57,13 @@ func TestHasManyAssociation(t *testing.T) {
|
|||||||
|
|
||||||
AssertAssociationCount(t, user, "Pets", 3, "AfterAppend")
|
AssertAssociationCount(t, user, "Pets", 3, "AfterAppend")
|
||||||
|
|
||||||
var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}}
|
var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}}
|
||||||
|
|
||||||
if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil {
|
if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil {
|
||||||
t.Fatalf("Error happened when append pet, got %v", err)
|
t.Fatalf("Error happened when append pet, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pet := range pets {
|
for _, pet := range pets2 {
|
||||||
var pet = pet
|
var pet = pet
|
||||||
if pet.ID == 0 {
|
if pet.ID == 0 {
|
||||||
t.Fatalf("Pet's ID should be created")
|
t.Fatalf("Pet's ID should be created")
|
||||||
|
@ -96,6 +96,18 @@ func TestCallbacks(t *testing.T) {
|
|||||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||||
results: []string{"c1", "c4", "c3"},
|
results: []string{"c1", "c4", "c3"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}},
|
||||||
|
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
|
||||||
|
results: []string{"c3", "c5", "c1", "c2", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
|
||||||
|
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, data := range datas {
|
for idx, data := range datas {
|
||||||
|
@ -2,6 +2,7 @@ package tests_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -55,4 +56,20 @@ func TestCount(t *testing.T) {
|
|||||||
if count3 != 2 {
|
if count3 != 2 {
|
||||||
t.Errorf("Should get correct count for count with group, but got %v", count3)
|
t.Errorf("Should get correct count for count with group, but got %v", count3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
result := dryDB.Table("users").Select("name").Count(&count)
|
||||||
|
if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
result = dryDB.Table("users").Distinct("name").Count(&count)
|
||||||
|
if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var count4 int64
|
||||||
|
if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
|
||||||
|
t.Errorf("count with join, got error: %v, count %v", err, count)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,45 @@ func TestCreate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateFromMap(t *testing.T) {
|
||||||
|
if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result User
|
||||||
|
if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 {
|
||||||
|
t.Fatalf("failed to create from map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result1 User
|
||||||
|
if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 {
|
||||||
|
t.Fatalf("failed to create from map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
datas := []map[string]interface{}{
|
||||||
|
{"Name": "create_from_map_2", "Age": 19},
|
||||||
|
{"name": "create_from_map_3", "Age": 20},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&User{}).Create(datas).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result2 User
|
||||||
|
if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 {
|
||||||
|
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result3 User
|
||||||
|
if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 {
|
||||||
|
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateWithAssociations(t *testing.T) {
|
func TestCreateWithAssociations(t *testing.T) {
|
||||||
var user = *GetUser("create_with_associations", Config{
|
var user = *GetUser("create_with_associations", Config{
|
||||||
Account: true,
|
Account: true,
|
||||||
@ -352,3 +391,22 @@ func TestOmitWithCreate(t *testing.T) {
|
|||||||
|
|
||||||
CheckUser(t, result2, user2)
|
CheckUser(t, result2, user2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFirstOrCreateWithPrimaryKey(t *testing.T) {
|
||||||
|
company := Company{ID: 100, Name: "company100_with_primarykey"}
|
||||||
|
DB.FirstOrCreate(&company)
|
||||||
|
|
||||||
|
if company.ID != 100 {
|
||||||
|
t.Errorf("invalid primary key after creating, got %v", company.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
companies := []Company{
|
||||||
|
{ID: 101, Name: "company101_with_primarykey"},
|
||||||
|
{ID: 102, Name: "company102_with_primarykey"},
|
||||||
|
}
|
||||||
|
DB.Create(&companies)
|
||||||
|
|
||||||
|
if companies[0].ID != 101 || companies[1].ID != 102 {
|
||||||
|
t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -70,8 +70,10 @@ func TestCustomizeField(t *testing.T) {
|
|||||||
FieldReadonly string `gorm:"->"`
|
FieldReadonly string `gorm:"->"`
|
||||||
FieldIgnore string `gorm:"-"`
|
FieldIgnore string `gorm:"-"`
|
||||||
AutoUnixCreateTime int64 `gorm:"autocreatetime"`
|
AutoUnixCreateTime int64 `gorm:"autocreatetime"`
|
||||||
|
AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"`
|
||||||
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
|
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
|
||||||
AutoUnixUpdateTime int64 `gorm:"autoupdatetime"`
|
AutoUnixUpdateTime int64 `gorm:"autoupdatetime"`
|
||||||
|
AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"`
|
||||||
AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"`
|
AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -118,6 +120,10 @@ func TestCustomizeField(t *testing.T) {
|
|||||||
t.Fatalf("invalid create/update unix time: %#v", result)
|
t.Fatalf("invalid create/update unix time: %#v", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 {
|
||||||
|
t.Fatalf("invalid create/update unix milli time: %#v", result)
|
||||||
|
}
|
||||||
|
|
||||||
if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 {
|
if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 {
|
||||||
t.Fatalf("invalid create/update unix nano time: %#v", result)
|
t.Fatalf("invalid create/update unix nano time: %#v", result)
|
||||||
}
|
}
|
||||||
@ -163,6 +169,8 @@ func TestCustomizeField(t *testing.T) {
|
|||||||
createWithDefaultTime := generateStruct("create_with_default_time")
|
createWithDefaultTime := generateStruct("create_with_default_time")
|
||||||
createWithDefaultTime.AutoUnixCreateTime = 100
|
createWithDefaultTime.AutoUnixCreateTime = 100
|
||||||
createWithDefaultTime.AutoUnixUpdateTime = 100
|
createWithDefaultTime.AutoUnixUpdateTime = 100
|
||||||
|
createWithDefaultTime.AutoUnixMilliCreateTime = 100
|
||||||
|
createWithDefaultTime.AutoUnixMilliUpdateTime = 100
|
||||||
createWithDefaultTime.AutoUnixNanoCreateTime = 100
|
createWithDefaultTime.AutoUnixNanoCreateTime = 100
|
||||||
createWithDefaultTime.AutoUnixNanoUpdateTime = 100
|
createWithDefaultTime.AutoUnixNanoUpdateTime = 100
|
||||||
DB.Create(&createWithDefaultTime)
|
DB.Create(&createWithDefaultTime)
|
||||||
@ -174,6 +182,10 @@ func TestCustomizeField(t *testing.T) {
|
|||||||
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
|
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
|
||||||
|
t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult)
|
||||||
|
}
|
||||||
|
|
||||||
if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
|
if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
|
||||||
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
|
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,7 @@ func TestDefaultValue(t *testing.T) {
|
|||||||
type Harumph struct {
|
type Harumph struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Email string `gorm:"not null;index:,unique"`
|
Email string `gorm:"not null;index:,unique"`
|
||||||
Name string `gorm:"not null;default:'foo'"`
|
Name string `gorm:"not null;default:foo"`
|
||||||
Name2 string `gorm:"size:233;not null;default:'foo'"`
|
Name2 string `gorm:"size:233;not null;default:'foo'"`
|
||||||
Name3 string `gorm:"size:233;not null;default:''"`
|
Name3 string `gorm:"size:233;not null;default:''"`
|
||||||
Age int `gorm:"default:18"`
|
Age int `gorm:"default:18"`
|
||||||
|
@ -43,6 +43,18 @@ func TestDelete(t *testing.T) {
|
|||||||
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
|
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := DB.Delete(users[0]).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when delete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause {
|
||||||
|
t.Errorf("errors happened when delete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
t.Errorf("should returns record not found error, but got %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteWithTable(t *testing.T) {
|
func TestDeleteWithTable(t *testing.T) {
|
||||||
@ -110,4 +122,8 @@ func TestBlockGlobalDelete(t *testing.T) {
|
|||||||
if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
|
if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
|
||||||
t.Errorf("should returns missing WHERE clause while deleting error")
|
t.Errorf("should returns missing WHERE clause while deleting error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil {
|
||||||
|
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEmbeddedStruct(t *testing.T) {
|
func TestEmbeddedStruct(t *testing.T) {
|
||||||
@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) {
|
|||||||
t.Errorf("Failed to create got error %v", err)
|
t.Errorf("Failed to create got error %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedRelations(t *testing.T) {
|
||||||
|
type AdvancedUser struct {
|
||||||
|
User `gorm:"embedded"`
|
||||||
|
Advanced bool
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&AdvancedUser{})
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
|
||||||
|
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
10
tests/go.mod
10
tests/go.mod
@ -6,11 +6,11 @@ require (
|
|||||||
github.com/google/uuid v1.1.1
|
github.com/google/uuid v1.1.1
|
||||||
github.com/jinzhu/now v1.1.1
|
github.com/jinzhu/now v1.1.1
|
||||||
github.com/lib/pq v1.6.0
|
github.com/lib/pq v1.6.0
|
||||||
gorm.io/driver/mysql v0.2.9
|
gorm.io/driver/mysql v1.0.0
|
||||||
gorm.io/driver/postgres v0.2.5
|
gorm.io/driver/postgres v1.0.0
|
||||||
gorm.io/driver/sqlite v1.0.8
|
gorm.io/driver/sqlite v1.1.0
|
||||||
gorm.io/driver/sqlserver v0.2.4
|
gorm.io/driver/sqlserver v1.0.0
|
||||||
gorm.io/gorm v0.2.19
|
gorm.io/gorm v1.9.19
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) {
|
|||||||
if db5.Error != nil {
|
if db5.Error != nil {
|
||||||
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
|
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement
|
||||||
|
|
||||||
|
if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("joins should be ordered, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJoinsWithSelect(t *testing.T) {
|
func TestJoinsWithSelect(t *testing.T) {
|
||||||
|
@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSmartMigrateColumn(t *testing.T) {
|
||||||
|
type UserMigrateColumn struct {
|
||||||
|
ID uint
|
||||||
|
Name string
|
||||||
|
Salary float64
|
||||||
|
Birthday time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&UserMigrateColumn{})
|
||||||
|
|
||||||
|
DB.AutoMigrate(&UserMigrateColumn{})
|
||||||
|
|
||||||
|
type UserMigrateColumn2 struct {
|
||||||
|
ID uint
|
||||||
|
Name string `gorm:"size:128"`
|
||||||
|
Salary float64 `gorm:"precision:2"`
|
||||||
|
Birthday time.Time `gorm:"precision:2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
|
||||||
|
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get column types, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
switch columnType.Name() {
|
||||||
|
case "name":
|
||||||
|
if length, _ := columnType.Length(); length != 0 && length != 128 {
|
||||||
|
t.Fatalf("name's length should be 128, but got %v", length)
|
||||||
|
}
|
||||||
|
case "salary":
|
||||||
|
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||||
|
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
|
||||||
|
}
|
||||||
|
case "birthday":
|
||||||
|
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||||
|
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserMigrateColumn3 struct {
|
||||||
|
ID uint
|
||||||
|
Name string `gorm:"size:256"`
|
||||||
|
Salary float64 `gorm:"precision:3"`
|
||||||
|
Birthday time.Time `gorm:"precision:3"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
|
||||||
|
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get column types, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
switch columnType.Name() {
|
||||||
|
case "name":
|
||||||
|
if length, _ := columnType.Length(); length != 0 && length != 256 {
|
||||||
|
t.Fatalf("name's length should be 128, but got %v", length)
|
||||||
|
}
|
||||||
|
case "salary":
|
||||||
|
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||||
|
t.Fatalf("salary's precision should be 2, but got %v", precision)
|
||||||
|
}
|
||||||
|
case "birthday":
|
||||||
|
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||||
|
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestMigrateWithComment(t *testing.T) {
|
func TestMigrateWithComment(t *testing.T) {
|
||||||
type UserWithComment struct {
|
type UserWithComment struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
|
@ -37,3 +37,40 @@ func TestPostgres(t *testing.T) {
|
|||||||
t.Errorf("No error should happen, but got %v", err)
|
t.Errorf("No error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
|
||||||
|
Title string
|
||||||
|
Categories []*Category `gorm:"Many2Many:post_categories"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Category struct {
|
||||||
|
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
|
||||||
|
Title string
|
||||||
|
Posts []*Post `gorm:"Many2Many:post_categories"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "postgres" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories")
|
||||||
|
DB.AutoMigrate(&Post{}, &Category{})
|
||||||
|
|
||||||
|
post := Post{
|
||||||
|
Title: "Hello World",
|
||||||
|
Categories: []*Category{
|
||||||
|
{Title: "Coding"},
|
||||||
|
{Title: "Golang"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&post).Error; err != nil {
|
||||||
|
t.Errorf("Failed, got error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CheckUser(t, users2[0], users[0])
|
CheckUser(t, users2[0], users[0])
|
||||||
|
|
||||||
|
var users3 []User
|
||||||
|
if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Table("accounts AS a").Select("a.*")
|
||||||
|
}).Find(&users3, "id IN ?", userIDs).Error; err != nil {
|
||||||
|
t.Errorf("failed to query, got error %v", err)
|
||||||
|
}
|
||||||
|
sort.Slice(users3, func(i, j int) bool {
|
||||||
|
return users2[i].ID < users2[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
for i, u := range users3 {
|
||||||
|
CheckUser(t, u, users[i])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNestedPreloadWithConds(t *testing.T) {
|
func TestNestedPreloadWithConds(t *testing.T) {
|
||||||
|
@ -103,6 +103,24 @@ func TestFind(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryWithAssociation(t *testing.T) {
|
||||||
|
user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3})
|
||||||
|
|
||||||
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when create user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user.CreatedAt = time.Time{}
|
||||||
|
user.UpdatedAt = time.Time{}
|
||||||
|
if err := DB.Where(&user).First(&User{}).Error; err != nil {
|
||||||
|
t.Errorf("search with struct with association should returns no error, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Where(user).First(&User{}).Error; err != nil {
|
||||||
|
t.Errorf("search with struct with association should returns no error, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindInBatches(t *testing.T) {
|
func TestFindInBatches(t *testing.T) {
|
||||||
var users = []User{
|
var users = []User{
|
||||||
*GetUser("find_in_batches", Config{}),
|
*GetUser("find_in_batches", Config{}),
|
||||||
@ -328,6 +346,11 @@ func TestSelect(t *testing.T) {
|
|||||||
if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
|
if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
|
||||||
t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
|
t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID)
|
||||||
|
if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
|
||||||
|
t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOmit(t *testing.T) {
|
func TestOmit(t *testing.T) {
|
||||||
|
@ -1,16 +1,20 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,6 +31,7 @@ func TestScannerValuer(t *testing.T) {
|
|||||||
Male: sql.NullBool{Bool: true, Valid: true},
|
Male: sql.NullBool{Bool: true, Valid: true},
|
||||||
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
|
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
|
||||||
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
||||||
|
Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}},
|
||||||
Password: EncryptedData("pass1"),
|
Password: EncryptedData("pass1"),
|
||||||
Bytes: []byte("byte"),
|
Bytes: []byte("byte"),
|
||||||
Num: 18,
|
Num: 18,
|
||||||
@ -36,6 +41,8 @@ func TestScannerValuer(t *testing.T) {
|
|||||||
{"name2", "value2"},
|
{"name2", "value2"},
|
||||||
},
|
},
|
||||||
Role: Role{Name: "admin"},
|
Role: Role{Name: "admin"},
|
||||||
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||||
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&data).Error; err != nil {
|
if err := DB.Create(&data).Error; err != nil {
|
||||||
@ -44,10 +51,17 @@ func TestScannerValuer(t *testing.T) {
|
|||||||
|
|
||||||
var result ScannerValuerStruct
|
var result ScannerValuerStruct
|
||||||
|
|
||||||
if err := DB.Find(&result).Error; err != nil {
|
if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
|
||||||
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result.ExampleStructPtr.Val != "value2" {
|
||||||
|
t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ExampleStruct.Val != "value1" {
|
||||||
|
t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct)
|
||||||
|
}
|
||||||
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,6 +75,8 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
|||||||
Name: sql.NullString{String: "name", Valid: true},
|
Name: sql.NullString{String: "name", Valid: true},
|
||||||
Gender: &sql.NullString{String: "M", Valid: true},
|
Gender: &sql.NullString{String: "M", Valid: true},
|
||||||
Age: sql.NullInt64{Int64: 18, Valid: true},
|
Age: sql.NullInt64{Int64: 18, Valid: true},
|
||||||
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||||
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
var result ScannerValuerStruct
|
var result ScannerValuerStruct
|
||||||
@ -100,6 +116,8 @@ func TestInvalidValuer(t *testing.T) {
|
|||||||
|
|
||||||
data := ScannerValuerStruct{
|
data := ScannerValuerStruct{
|
||||||
Password: EncryptedData("xpass1"),
|
Password: EncryptedData("xpass1"),
|
||||||
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||||
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&data).Error; err == nil {
|
if err := DB.Create(&data).Error; err == nil {
|
||||||
@ -130,12 +148,18 @@ type ScannerValuerStruct struct {
|
|||||||
Male sql.NullBool
|
Male sql.NullBool
|
||||||
Height sql.NullFloat64
|
Height sql.NullFloat64
|
||||||
Birthday sql.NullTime
|
Birthday sql.NullTime
|
||||||
|
Allergen NullString
|
||||||
Password EncryptedData
|
Password EncryptedData
|
||||||
Bytes []byte
|
Bytes []byte
|
||||||
Num Num
|
Num Num
|
||||||
Strings StringsSlice
|
Strings StringsSlice
|
||||||
Structs StructsSlice
|
Structs StructsSlice
|
||||||
Role Role
|
Role Role
|
||||||
|
UserID *sql.NullInt64
|
||||||
|
User User
|
||||||
|
EmptyTime EmptyTime
|
||||||
|
ExampleStruct ExampleStruct
|
||||||
|
ExampleStructPtr *ExampleStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncryptedData []byte
|
type EncryptedData []byte
|
||||||
@ -201,7 +225,31 @@ func (l *StringsSlice) Scan(input interface{}) error {
|
|||||||
|
|
||||||
type ExampleStruct struct {
|
type ExampleStruct struct {
|
||||||
Name string
|
Name string
|
||||||
Value string
|
Val string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ExampleStruct) GormDataType() string {
|
||||||
|
return "bytes"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ExampleStruct) Value() (driver.Value, error) {
|
||||||
|
if len(s.Name) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// for test, has no practical meaning
|
||||||
|
s.Name = ""
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ExampleStruct) Scan(src interface{}) error {
|
||||||
|
switch value := src.(type) {
|
||||||
|
case string:
|
||||||
|
return json.Unmarshal([]byte(value), s)
|
||||||
|
case []byte:
|
||||||
|
return json.Unmarshal(value, s)
|
||||||
|
default:
|
||||||
|
return errors.New("not supported")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type StructsSlice []ExampleStruct
|
type StructsSlice []ExampleStruct
|
||||||
@ -242,3 +290,104 @@ func (role Role) Value() (driver.Value, error) {
|
|||||||
func (role Role) IsAdmin() bool {
|
func (role Role) IsAdmin() bool {
|
||||||
return role.Name == "admin"
|
return role.Name == "admin"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmptyTime struct {
|
||||||
|
time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *EmptyTime) Scan(v interface{}) error {
|
||||||
|
nullTime := sql.NullTime{}
|
||||||
|
err := nullTime.Scan(v)
|
||||||
|
t.Time = nullTime.Time
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t EmptyTime) Value() (driver.Value, error) {
|
||||||
|
return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type NullString struct {
|
||||||
|
sql.NullString
|
||||||
|
}
|
||||||
|
|
||||||
|
type Point struct {
|
||||||
|
X, Y int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (point Point) GormDataType() string {
|
||||||
|
return "geo"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
||||||
|
return clause.Expr{
|
||||||
|
SQL: "ST_PointFromText(?)",
|
||||||
|
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGORMValuer(t *testing.T) {
|
||||||
|
type UserWithPoint struct {
|
||||||
|
Name string
|
||||||
|
Point Point
|
||||||
|
}
|
||||||
|
|
||||||
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
|
||||||
|
stmt := dryRunDB.Create(&UserWithPoint{
|
||||||
|
Name: "jinzhu",
|
||||||
|
Point: Point{X: 100, Y: 100},
|
||||||
|
}).Statement
|
||||||
|
|
||||||
|
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{
|
||||||
|
"Name": "jinzhu",
|
||||||
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
||||||
|
}).Statement
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{
|
||||||
|
"Name": "jinzhu",
|
||||||
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
||||||
|
}).Statement
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Session(&gorm.Session{
|
||||||
|
AllowGlobalUpdate: true,
|
||||||
|
}).Model(&UserWithPoint{}).Updates(UserWithPoint{
|
||||||
|
Name: "jinzhu",
|
||||||
|
Point: Point{X: 100, Y: 100},
|
||||||
|
}).Statement
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -40,6 +40,17 @@ func TestTable(t *testing.T) {
|
|||||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r = dryDB.Create(&UserWithTable{}).Statement
|
||||||
|
if DB.Dialector.Name() != "sqlite" {
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) {
|
||||||
|
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) {
|
||||||
|
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement
|
r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement
|
||||||
if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
|
if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
|
||||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||||
|
@ -64,6 +64,8 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||||||
// USE gorm;
|
// USE gorm;
|
||||||
// CREATE USER gorm FROM LOGIN gorm;
|
// CREATE USER gorm FROM LOGIN gorm;
|
||||||
// sp_changedbowner 'gorm';
|
// sp_changedbowner 'gorm';
|
||||||
|
// npm install -g sql-cli
|
||||||
|
// mssql -u gorm -p LoremIpsum86 -d gorm -o 9930
|
||||||
log.Println("testing sqlserver...")
|
log.Println("testing sqlserver...")
|
||||||
if dbDSN == "" {
|
if dbDSN == "" {
|
||||||
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCancelTransaction(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx, cancelFunc := context.WithCancel(ctx)
|
||||||
|
cancelFunc()
|
||||||
|
|
||||||
|
user := *GetUser("cancel_transaction", Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
|
var result User
|
||||||
|
tx.First(&result, user.ID)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Transaction should get error when using cancelled context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTransactionWithBlock(t *testing.T) {
|
func TestTransactionWithBlock(t *testing.T) {
|
||||||
assertPanic := func(f func()) {
|
assertPanic := func(f func()) {
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -222,6 +222,10 @@ func TestBlockGlobalUpdate(t *testing.T) {
|
|||||||
if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
|
if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
|
||||||
t.Errorf("should returns missing WHERE clause while updating error, got err %v", err)
|
t.Errorf("should returns missing WHERE clause while updating error, got err %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil {
|
||||||
|
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectWithUpdate(t *testing.T) {
|
func TestSelectWithUpdate(t *testing.T) {
|
||||||
@ -334,6 +338,15 @@ func TestSelectWithUpdateWithMap(t *testing.T) {
|
|||||||
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
|
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithUpdateWithInvalidMap(t *testing.T) {
|
||||||
|
user := *GetUser("update_with_invalid_map", Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) {
|
||||||
|
t.Errorf("should returns error for unsupported updating data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOmitWithUpdate(t *testing.T) {
|
func TestOmitWithUpdate(t *testing.T) {
|
||||||
user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||||
DB.Create(&user)
|
DB.Create(&user)
|
||||||
@ -545,3 +558,31 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) {
|
|||||||
t.Errorf("element's ignored field should not be updated")
|
t.Errorf("element's ignored field should not be updated")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateFromSubQuery(t *testing.T) {
|
||||||
|
user := *GetUser("update_from_sub_query", Config{Company: true})
|
||||||
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
|
t.Errorf("failed to create user, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil {
|
||||||
|
t.Errorf("failed to update with sub query, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result User
|
||||||
|
DB.First(&result, user.ID)
|
||||||
|
|
||||||
|
if result.Name != user.Company.Name {
|
||||||
|
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&user.Company).Update("Name", "new company name")
|
||||||
|
if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil {
|
||||||
|
t.Errorf("failed to update with sub query, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.First(&result, user.ID)
|
||||||
|
if result.Name != "new company name" {
|
||||||
|
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
name := reflect.ValueOf(got).Type().Elem().Name()
|
name := reflect.ValueOf(got).Type().Elem().Name()
|
||||||
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
|
t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -29,8 +29,8 @@ func FileWithLineNum() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsChar(c rune) bool {
|
func IsValidDBNameChar(c rune) bool {
|
||||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*'
|
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckTruth(val interface{}) bool {
|
func CheckTruth(val interface{}) bool {
|
||||||
|
14
utils/utils_test.go
Normal file
14
utils/utils_test.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsValidDBNameChar(t *testing.T) {
|
||||||
|
for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} {
|
||||||
|
if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 {
|
||||||
|
t.Fatalf("failed to parse db name %v", db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user