Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
68640733b6
10
.github/workflows/tests.yml
vendored
10
.github/workflows/tests.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
sqlite:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.14', '1.13']
|
||||
go: ['1.15', '1.14', '1.13']
|
||||
platform: [ubuntu-latest, macos-latest] # can not run in windows OS
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -38,7 +38,7 @@ jobs:
|
||||
sqlite_windows:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.14', '1.13']
|
||||
go: ['1.15', '1.14', '1.13']
|
||||
platform: [windows-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -64,7 +64,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
|
||||
go: ['1.14', '1.13']
|
||||
go: ['1.15', '1.14', '1.13']
|
||||
platform: [ubuntu-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -108,7 +108,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
|
||||
go: ['1.14', '1.13']
|
||||
go: ['1.15', '1.14', '1.13']
|
||||
platform: [ubuntu-latest] # can not run in macOS and widnowsOS
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -150,7 +150,7 @@ jobs:
|
||||
sqlserver:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.14', '1.13']
|
||||
go: ['1.15', '1.14', '1.13']
|
||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
|
@ -19,11 +19,11 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||
* Context, Prepared Statment Mode, DryRun Mode
|
||||
* Batch Insert, FindInBatches, Find To Map
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
|
||||
* Composite Primary Key
|
||||
* Auto Migrations
|
||||
* Logger
|
||||
* Extendable, write Plugins based on GORM callbacks
|
||||
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
|
||||
* Every feature comes with tests
|
||||
* Developer Friendly
|
||||
|
||||
@ -40,4 +40,3 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
© Jinzhu, 2013~time.Now
|
||||
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
|
||||
|
||||
|
@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association {
|
||||
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
var (
|
||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
tx = association.DB.Model(out)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE", "LIMIT")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: queryConds},
|
||||
}}})
|
||||
} else {
|
||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||
}
|
||||
|
||||
association.Error = tx.Find(out, conds...).Error
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
}
|
||||
|
||||
return association.Error
|
||||
}
|
||||
|
||||
@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error {
|
||||
association.Error = association.Replace(values...)
|
||||
}
|
||||
default:
|
||||
association.saveAssociation(false, values...)
|
||||
association.saveAssociation( /*clear*/ false, values...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error {
|
||||
func (association *Association) Replace(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
// save associations
|
||||
association.saveAssociation(true, values...)
|
||||
association.saveAssociation( /*clear*/ true, values...)
|
||||
|
||||
// set old associations's foreign key to null
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
@ -112,7 +88,7 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
updateMap[ref.ForeignKey.DBName] = nil
|
||||
}
|
||||
|
||||
association.DB.UpdateColumns(updateMap)
|
||||
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
var (
|
||||
@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error
|
||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
|
||||
}
|
||||
|
||||
if association.Error == nil {
|
||||
// clean up deleted values's foreign key
|
||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
||||
|
||||
cleanUpDeletedRelations := func(data reflect.Value) {
|
||||
@ -328,33 +305,8 @@ func (association *Association) Clear() error {
|
||||
|
||||
func (association *Association) Count() (count int64) {
|
||||
if association.Error == nil {
|
||||
var (
|
||||
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE", "LIMIT")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: conds},
|
||||
}}})
|
||||
} else {
|
||||
tx.Clauses(clause.Where{Exprs: conds})
|
||||
}
|
||||
|
||||
association.Error = tx.Count(&count).Error
|
||||
association.Error = association.buildCondition().Count(&count).Error
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if len(values) != reflectValue.Len() {
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
||||
@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
|
||||
}
|
||||
case reflect.Struct:
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||
|
||||
@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (association *Association) buildCondition() *DB {
|
||||
var (
|
||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: queryConds},
|
||||
}}})
|
||||
} else {
|
||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
@ -48,14 +48,19 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
||||
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
||||
objs = append(objs, obj)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, rv)
|
||||
} else {
|
||||
elems = reflect.Append(elems, rv.Addr())
|
||||
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
||||
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
||||
objs = append(objs, obj)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, rv)
|
||||
} else {
|
||||
elems = reflect.Append(elems, rv.Addr())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,22 +117,24 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
|
||||
if _, zero := rel.Field.ValueOf(obj); !zero {
|
||||
rv := rel.Field.ReflectValueOf(obj)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(obj)
|
||||
db.AddError(ref.ForeignKey.Set(rv, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
if _, zero := rel.Field.ValueOf(obj); !zero {
|
||||
rv := rel.Field.ReflectValueOf(obj)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
}
|
||||
|
||||
elems = reflect.Append(elems, rv)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(obj)
|
||||
db.AddError(ref.ForeignKey.Set(rv, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
elems = reflect.Append(elems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
appendToElems(db.Statement.ReflectValue.Index(i))
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
appendToElems(db.Statement.ReflectValue.Index(i))
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
|
@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(gorm.BeforeCreateInterface); ok {
|
||||
if i, ok := value.(BeforeCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeCreate(tx))
|
||||
}
|
||||
@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i))
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
||||
if isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID--
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID++
|
||||
}
|
||||
}
|
||||
@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) {
|
||||
for rows.Next() {
|
||||
BEGIN:
|
||||
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
|
||||
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
for idx, field := range fields {
|
||||
fieldValue := field.ReflectValueOf(reflectValue)
|
||||
|
||||
@ -189,14 +203,14 @@ func AfterCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(gorm.AfterCreateInterface); ok {
|
||||
if i, ok := value.(AfterCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterCreate(tx))
|
||||
}
|
||||
@ -211,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, value)
|
||||
case *map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, *value)
|
||||
case []map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
case *[]map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
|
||||
default:
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
@ -295,6 +313,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
|
||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
@ -41,7 +41,7 @@ func Delete(db *gorm.DB) {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
@ -51,15 +51,15 @@ func Delete(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
db.Statement.Build("DELETE", "FROM", "WHERE")
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
@ -75,7 +75,7 @@ func Delete(db *gorm.DB) {
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.AfterDeleteInterface); ok {
|
||||
if i, ok := value.(AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
|
@ -20,12 +20,18 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
|
||||
|
||||
for _, k := range keys {
|
||||
value := mapValue[k]
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: k})
|
||||
if len(values.Values) == 0 {
|
||||
values.Values = [][]interface{}{{}}
|
||||
}
|
||||
|
||||
values.Values[0] = append(values.Values[0], value)
|
||||
}
|
||||
}
|
||||
@ -42,8 +48,10 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
||||
|
||||
for idx, mapValue := range mapValues {
|
||||
for k, v := range mapValue {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := result[k]; !ok {
|
||||
@ -61,11 +69,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
||||
|
||||
sort.Strings(columns)
|
||||
values.Values = make([][]interface{}, len(mapValues))
|
||||
values.Columns = make([]clause.Column, len(columns))
|
||||
for idx, column := range columns {
|
||||
values.Columns[idx] = clause.Column{Name: column}
|
||||
|
||||
for i, v := range result[column] {
|
||||
if i == 0 {
|
||||
if len(values.Values[i]) == 0 {
|
||||
values.Values[i] = make([]interface{}, len(columns))
|
||||
}
|
||||
|
||||
values.Values[i][idx] = v
|
||||
}
|
||||
}
|
||||
|
39
callbacks/interfaces.go
Normal file
39
callbacks/interfaces.go
Normal file
@ -0,0 +1,39 @@
|
||||
package callbacks
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*gorm.DB) error
|
||||
}
|
@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||
|
||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
|
||||
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
|
||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
|
||||
|
||||
// convert join identity map to relation identity map
|
||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||
@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||
}
|
||||
|
||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues)
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
||||
for _, cond := range conds {
|
||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||
@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...)
|
||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
|
||||
|
||||
fieldValues := make([]interface{}, len(relForeignFields))
|
||||
|
||||
|
@ -96,7 +96,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
|
||||
// inline joins
|
||||
if len(db.Statement.Joins) != 0 {
|
||||
if len(db.Statement.Selects) == 0 {
|
||||
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
joins := []clause.Join{}
|
||||
for name, conds := range db.Statement.Joins {
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema == nil {
|
||||
joins = append(joins, clause.Join{
|
||||
Expression: clause.Expr{SQL: name, Vars: conds},
|
||||
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok {
|
||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||
tableAliasName := relation.Name
|
||||
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
})
|
||||
} else {
|
||||
joins = append(joins, clause.Join{
|
||||
Expression: clause.Expr{SQL: name, Vars: conds},
|
||||
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -214,7 +214,7 @@ func Preload(db *gorm.DB) {
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.AfterFindInterface); ok {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
db.AddError(i.AfterFind(tx))
|
||||
return true
|
||||
}
|
||||
|
@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeUpdate {
|
||||
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
|
||||
if i, ok := value.(BeforeUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeUpdate(tx))
|
||||
}
|
||||
@ -69,7 +69,7 @@ func Update(db *gorm.DB) {
|
||||
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterUpdate {
|
||||
if i, ok := value.(gorm.AfterUpdateInterface); ok {
|
||||
if i, ok := value.(AfterUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterUpdate(tx))
|
||||
}
|
||||
@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
kv := value[k]
|
||||
if _, ok := kv.(*gorm.DB); ok {
|
||||
kv = []interface{}{kv}
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
if field.DBName != "" {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
||||
@ -189,12 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
|
||||
}
|
||||
}
|
||||
|
||||
if !stmt.UpdatingColumn && stmt.Schema != nil {
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.LookUpField(dbName)
|
||||
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
now := stmt.DB.NowFunc()
|
||||
@ -217,7 +223,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
switch updatingValue.Kind() {
|
||||
case reflect.Struct:
|
||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.LookUpField(dbName)
|
||||
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
value, isZero := field.ValueOf(updatingValue)
|
||||
@ -247,6 +254,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,7 +69,7 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
if len(args) > 0 {
|
||||
tx = tx.Select(args[0], args[1:]...)
|
||||
}
|
||||
return tx
|
||||
return
|
||||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
|
||||
@ -91,8 +91,9 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
}
|
||||
delete(tx.Statement.Clauses, "SELECT")
|
||||
case string:
|
||||
fields := strings.FieldsFunc(v, utils.IsChar)
|
||||
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar)
|
||||
|
||||
// normal field names
|
||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||
@ -112,6 +113,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
delete(tx.Statement.Clauses, "SELECT")
|
||||
} else {
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
@ -130,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
|
||||
} else {
|
||||
tx.Statement.Omits = columns
|
||||
}
|
||||
@ -169,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Joins == nil {
|
||||
tx.Statement.Joins = map[string][]interface{}{}
|
||||
}
|
||||
tx.Statement.Joins[query] = args
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||
return
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
fields := strings.FieldsFunc(name, utils.IsChar)
|
||||
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
|
||||
})
|
||||
|
@ -19,6 +19,8 @@ var (
|
||||
ErrPrimaryKeyRequired = errors.New("primary key required")
|
||||
// ErrModelValueRequired model value required
|
||||
ErrModelValueRequired = errors.New("model value required")
|
||||
// ErrInvalidData unsupported data
|
||||
ErrInvalidData = errors.New("unsupported data")
|
||||
// ErrUnsupportedDriver unsupported driver
|
||||
ErrUnsupportedDriver = errors.New("unsupported driver")
|
||||
// ErrRegistered registered
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
@ -132,19 +133,46 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
return
|
||||
}
|
||||
|
||||
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
if field := tx.Statement.Schema.LookUpField(column); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
for _, value := range values {
|
||||
switch v := value.(type) {
|
||||
case []clause.Expression:
|
||||
for _, expr := range v {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
if field := tx.Statement.Schema.LookUpField(column); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
case clause.Column:
|
||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
case clause.Column:
|
||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||
exprs := tx.Statement.BuildCondition(value)
|
||||
tx.assignInterfacesToValue(exprs)
|
||||
default:
|
||||
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, f := range s.Fields {
|
||||
if f.Readable {
|
||||
if v, isZero := f.ValueOf(reflectValue); !isZero {
|
||||
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
} else if len(values) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
|
||||
tx.assignInterfacesToValue(exprs)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -154,22 +182,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignExprsToValue(where.Exprs)
|
||||
tx.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||
}
|
||||
tx.Error = nil
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -180,20 +206,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignExprsToValue(where.Exprs)
|
||||
tx.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
@ -265,11 +289,14 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model == nil {
|
||||
tx.Statement.Model = tx.Statement.Dest
|
||||
defer func() {
|
||||
tx.Statement.Model = nil
|
||||
}()
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
||||
defer tx.Statement.AddClause(clause.Select{})
|
||||
defer delete(tx.Statement.Clauses, "SELECT")
|
||||
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
|
||||
expr := clause.Expr{SQL: "count(1)"}
|
||||
|
||||
@ -335,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx.AddError(ErrModelValueRequired)
|
||||
}
|
||||
|
||||
fields := strings.FieldsFunc(column, utils.IsChar)
|
||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||
@ -420,7 +447,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
|
||||
// Commit commit a transaction
|
||||
func (db *DB) Commit() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Commit())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
@ -431,7 +458,9 @@ func (db *DB) Commit() *DB {
|
||||
// Rollback rollback a transaction
|
||||
func (db *DB) Rollback() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
db.AddError(committer.Rollback())
|
||||
if !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Rollback())
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
|
7
gorm.go
7
gorm.go
@ -32,6 +32,8 @@ type Config struct {
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
DisableForeignKeyConstraintWhenMigrating bool
|
||||
// AllowGlobalUpdate allow global update
|
||||
AllowGlobalUpdate bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
@ -61,6 +63,7 @@ type Session struct {
|
||||
PrepareStmt bool
|
||||
WithConditions bool
|
||||
SkipDefaultTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
@ -154,6 +157,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||
tx.Config.SkipDefaultTransaction = true
|
||||
}
|
||||
|
||||
if config.AllowGlobalUpdate {
|
||||
txConfig.AllowGlobalUpdate = true
|
||||
}
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
|
@ -53,38 +53,7 @@ type TxCommitter interface {
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*DB) error
|
||||
// Valuer gorm valuer interface
|
||||
type Valuer interface {
|
||||
GormValue(context.Context, *DB) clause.Expr
|
||||
}
|
||||
|
@ -42,6 +42,7 @@ type Migrator interface {
|
||||
AddColumn(dst interface{}, field string) error
|
||||
DropColumn(dst interface{}, field string) error
|
||||
AlterColumn(dst interface{}, field string) error
|
||||
MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error
|
||||
HasColumn(dst interface{}, field string) bool
|
||||
RenameColumn(dst interface{}, oldName, field string) error
|
||||
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
|
||||
// AutoMigrate
|
||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
// TODO smart migrate data type
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if !tx.Migrator().HasTable(value) {
|
||||
@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
if !tx.Migrator().HasColumn(value, field.DBName) {
|
||||
var foundColumn *sql.ColumnType
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
if columnType.Name() == field.DBName {
|
||||
foundColumn = columnType
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
// found, smart migrate
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error {
|
||||
// found, smart migrate
|
||||
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
alterColumn := false
|
||||
|
||||
// check size
|
||||
if length, _ := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
|
||||
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
||||
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check nullable
|
||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||
// not primary key & database is nullable
|
||||
if !field.PrimaryKey && nullable {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
if alterColumn {
|
||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
||||
|
147
schema/field.go
147
schema/field.go
@ -55,6 +55,7 @@ type Field struct {
|
||||
Comment string
|
||||
Size int
|
||||
Precision int
|
||||
Scale int
|
||||
FieldType reflect.Type
|
||||
IndirectFieldType reflect.Type
|
||||
StructField reflect.StructField
|
||||
@ -88,53 +89,44 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(field.IndirectFieldType)
|
||||
|
||||
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
|
||||
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
|
||||
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
|
||||
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
|
||||
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...)
|
||||
}
|
||||
|
||||
// if field is valuer, used its value or first fields as data type
|
||||
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
|
||||
if isValuer {
|
||||
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
|
||||
var overrideFieldValue bool
|
||||
if v, err := valuer.Value(); v != nil && err == nil {
|
||||
overrideFieldValue = true
|
||||
if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
|
||||
fieldValue = reflect.ValueOf(v)
|
||||
}
|
||||
|
||||
if field.IndirectFieldType.Kind() == reflect.Struct {
|
||||
for i := 0; i < field.IndirectFieldType.NumField(); i++ {
|
||||
if !overrideFieldValue {
|
||||
newFieldType := field.IndirectFieldType.Field(i).Type
|
||||
var getRealFieldValue func(reflect.Value)
|
||||
getRealFieldValue = func(v reflect.Value) {
|
||||
rv := reflect.Indirect(v)
|
||||
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
||||
for i := 0; i < rv.Type().NumField(); i++ {
|
||||
newFieldType := rv.Type().Field(i).Type
|
||||
for newFieldType.Kind() == reflect.Ptr {
|
||||
newFieldType = newFieldType.Elem()
|
||||
}
|
||||
|
||||
fieldValue = reflect.New(newFieldType)
|
||||
overrideFieldValue = true
|
||||
}
|
||||
|
||||
// copy tag settings from valuer
|
||||
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
|
||||
if _, ok := field.TagSettings[key]; !ok {
|
||||
field.TagSettings[key] = value
|
||||
if rv.Type() != reflect.Indirect(fieldValue).Type() {
|
||||
getRealFieldValue(fieldValue)
|
||||
}
|
||||
|
||||
if fieldValue.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
|
||||
if _, ok := field.TagSettings[key]; !ok {
|
||||
field.TagSettings[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getRealFieldValue(fieldValue)
|
||||
}
|
||||
}
|
||||
|
||||
@ -169,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
field.Precision, _ = strconv.Atoi(p)
|
||||
}
|
||||
|
||||
if s, ok := field.TagSettings["SCALE"]; ok {
|
||||
field.Scale, _ = strconv.Atoi(s)
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
||||
field.NotNull = true
|
||||
}
|
||||
@ -311,15 +307,19 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) {
|
||||
if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) {
|
||||
if reflect.Indirect(fieldValue).Kind() == reflect.Struct {
|
||||
var err error
|
||||
field.Creatable = false
|
||||
field.Updatable = false
|
||||
field.Readable = false
|
||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
||||
|
||||
cacheStore := &sync.Map{}
|
||||
cacheStore.Store(embeddedCacheKey, true)
|
||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
|
||||
for _, ef := range field.EmbeddedSchema.Fields {
|
||||
ef.Schema = schema
|
||||
ef.OwnerSchema = field.EmbeddedSchema
|
||||
@ -341,17 +341,20 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
ef.PrimaryKey = true
|
||||
} else {
|
||||
ef.PrimaryKey = false
|
||||
|
||||
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
|
||||
ef.AutoIncrement = false
|
||||
}
|
||||
|
||||
if ef.DefaultValue == "" {
|
||||
ef.HasDefaultValue = false
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range field.TagSettings {
|
||||
ef.TagSettings[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...)
|
||||
field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...)
|
||||
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...)
|
||||
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...)
|
||||
} else {
|
||||
schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
|
||||
}
|
||||
@ -475,16 +478,16 @@ func (field *Field) setupValuerAndSetter() {
|
||||
}
|
||||
}
|
||||
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
err = setter(value, v)
|
||||
}
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
err = setter(value, reflectV.Elem().Interface())
|
||||
}
|
||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
err = setter(value, v)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
||||
}
|
||||
@ -731,40 +734,10 @@ func (field *Field) setupValuerAndSetter() {
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||
// struct scanner
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(value, reflectV.Elem().Interface())
|
||||
}
|
||||
} else {
|
||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||
// pointer scanner
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
if valuer == nil || reflectV.IsNil() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
}
|
||||
|
||||
if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
@ -778,10 +751,38 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||
// struct scanner
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(value, reflectV.Elem().Interface())
|
||||
}
|
||||
} else {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else {
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
|
@ -1,23 +1,25 @@
|
||||
package schema
|
||||
|
||||
import "gorm.io/gorm/clause"
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type GormDataTypeInterface interface {
|
||||
GormDataType() string
|
||||
}
|
||||
|
||||
type CreateClausesInterface interface {
|
||||
CreateClauses() []clause.Interface
|
||||
CreateClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
type QueryClausesInterface interface {
|
||||
QueryClauses() []clause.Interface
|
||||
QueryClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
type UpdateClausesInterface interface {
|
||||
UpdateClauses() []clause.Interface
|
||||
UpdateClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
type DeleteClausesInterface interface {
|
||||
DeleteClauses() []clause.Interface
|
||||
DeleteClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
@ -39,3 +39,24 @@ type AdvancedDataTypeUser struct {
|
||||
Active mybool
|
||||
Admin *mybool
|
||||
}
|
||||
|
||||
type BaseModel struct {
|
||||
ID uint
|
||||
CreatedAt time.Time
|
||||
CreatedBy *int
|
||||
Created *VersionUser `gorm:"foreignKey:CreatedBy"`
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
type VersionModel struct {
|
||||
BaseModel
|
||||
Version int
|
||||
}
|
||||
|
||||
type VersionUser struct {
|
||||
VersionModel
|
||||
Name string
|
||||
Age uint
|
||||
Birthday *time.Time
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -67,16 +66,14 @@ func (schema *Schema) parseRelation(field *Field) {
|
||||
}
|
||||
)
|
||||
|
||||
cacheStore := schema.cacheStore
|
||||
if field.OwnerSchema != nil {
|
||||
if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
}
|
||||
cacheStore = field.OwnerSchema.cacheStore
|
||||
}
|
||||
|
||||
if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
}
|
||||
|
||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
@ -85,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) {
|
||||
schema.buildMany2ManyRelation(relation, field, many2many)
|
||||
} else {
|
||||
switch field.IndirectFieldType.Kind() {
|
||||
case reflect.Struct, reflect.Slice:
|
||||
case reflect.Struct:
|
||||
schema.guessRelation(relation, field, guessBelongs)
|
||||
case reflect.Slice:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
default:
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
||||
@ -228,7 +227,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
Name: joinFieldName,
|
||||
PkgPath: ownField.StructField.PkgPath,
|
||||
Type: ownField.StructField.Type,
|
||||
Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"),
|
||||
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
}
|
||||
|
||||
@ -251,7 +250,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
Name: joinFieldName,
|
||||
PkgPath: relField.StructField.PkgPath,
|
||||
Type: relField.StructField.Type,
|
||||
Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"),
|
||||
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
}
|
||||
|
||||
@ -327,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
type guessLevel int
|
||||
|
||||
const (
|
||||
guessHas guessLevel = iota
|
||||
guessEmbeddedHas
|
||||
guessBelongs
|
||||
guessBelongs guessLevel = iota
|
||||
guessEmbeddedBelongs
|
||||
guessHas
|
||||
guessEmbeddedHas
|
||||
)
|
||||
|
||||
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
|
||||
@ -339,34 +338,36 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
|
||||
primarySchema, foreignSchema = schema, relation.FieldSchema
|
||||
)
|
||||
|
||||
reguessOrErr := func(err string, args ...interface{}) {
|
||||
reguessOrErr := func() {
|
||||
switch gl {
|
||||
case guessHas:
|
||||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
case guessEmbeddedHas:
|
||||
schema.guessRelation(relation, field, guessBelongs)
|
||||
case guessBelongs:
|
||||
schema.guessRelation(relation, field, guessEmbeddedBelongs)
|
||||
case guessEmbeddedBelongs:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
case guessHas:
|
||||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
// case guessEmbeddedHas:
|
||||
default:
|
||||
schema.err = fmt.Errorf(err, args...)
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
switch gl {
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
} else {
|
||||
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||
return
|
||||
}
|
||||
case guessBelongs:
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
case guessEmbeddedBelongs:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
} else {
|
||||
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
case guessHas:
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
} else {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -376,7 +377,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
|
||||
if f := foreignSchema.LookUpField(foreignKey); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
} else {
|
||||
reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys)
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -395,7 +396,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
|
||||
}
|
||||
|
||||
if len(foreignFields) == 0 {
|
||||
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||
reguessOrErr()
|
||||
return
|
||||
} else if len(relation.primaryKeys) > 0 {
|
||||
for idx, primaryKey := range relation.primaryKeys {
|
||||
@ -403,11 +404,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
|
||||
if len(primaryFields) < idx+1 {
|
||||
primaryFields = append(primaryFields, f)
|
||||
} else if f != primaryFields[idx] {
|
||||
reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -417,7 +418,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
|
||||
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
||||
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
|
||||
} else {
|
||||
reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
|
||||
type User struct {
|
||||
ID int32 `gorm:"primaryKey"`
|
||||
Name string
|
||||
CreatedBy *int32
|
||||
Creator *User `gorm:"foreignKey:CreatedBy;references:ID"`
|
||||
}
|
||||
|
||||
checkStructRelation(t, &User{}, Relation{
|
||||
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
|
||||
References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}},
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasOneOverrideForeignKey(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
|
@ -212,11 +212,29 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||
}
|
||||
|
||||
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
|
||||
// parse relations for unidentified fields
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && field.Creatable {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, schema.err
|
||||
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && field.Creatable {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, schema.err
|
||||
}
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(field.IndirectFieldType)
|
||||
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
|
||||
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
|
||||
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
|
||||
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
|
||||
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -49,10 +49,15 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
|
||||
}
|
||||
}
|
||||
|
||||
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
|
||||
parsedField, ok := s.FieldsByDBName[f.DBName]
|
||||
if !ok {
|
||||
parsedField, ok = s.FieldsByName[f.Name]
|
||||
}
|
||||
|
||||
if !ok {
|
||||
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
|
||||
} else {
|
||||
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
|
||||
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings")
|
||||
|
||||
if f.DBName != "" {
|
||||
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
|
||||
@ -62,7 +67,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
|
||||
|
||||
for _, name := range []string{f.DBName, f.Name} {
|
||||
if name != "" {
|
||||
if field := s.LookUpField(name); field == nil || parsedField != field {
|
||||
if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) {
|
||||
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
@ -160,3 +161,65 @@ func TestCustomizeTableName(t *testing.T) {
|
||||
t.Errorf("Failed to customize table with TableName method")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedModel(t *testing.T) {
|
||||
versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse nested user, got error %v", err)
|
||||
}
|
||||
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true},
|
||||
{Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64},
|
||||
{Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, versionUser, &f, func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedStruct(t *testing.T) {
|
||||
type CorpBase struct {
|
||||
gorm.Model
|
||||
OwnerID string
|
||||
}
|
||||
|
||||
type Company struct {
|
||||
ID int
|
||||
OwnerID int
|
||||
Name string
|
||||
}
|
||||
|
||||
type Corp struct {
|
||||
CorpBase
|
||||
Base Company `gorm:"embedded;embeddedPrefix:company_"`
|
||||
}
|
||||
|
||||
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
|
||||
}
|
||||
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
||||
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
var embeddedCacheKey = "embedded_cache_store"
|
||||
|
||||
func ParseTagSetting(str string, sep string) map[string]string {
|
||||
settings := map[string]string{}
|
||||
names := strings.Split(str, sep)
|
||||
@ -49,8 +51,11 @@ func toColumns(val string) (results []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag {
|
||||
return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
|
||||
func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
|
||||
for _, name := range names {
|
||||
tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
// GetRelationsValues get relations's values from a reflect value
|
||||
|
@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) {
|
||||
return n.Time, nil
|
||||
}
|
||||
|
||||
func (DeletedAt) QueryClauses() []clause.Interface {
|
||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{
|
||||
clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"},
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
|
||||
Value: nil,
|
||||
},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) DeleteClauses() []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteClause{}}
|
||||
type SoftDeleteQueryClause struct {
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
type SoftDeleteClause struct {
|
||||
}
|
||||
|
||||
func (SoftDeleteClause) Name() string {
|
||||
func (sd SoftDeleteQueryClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (SoftDeleteClause) Build(clause.Builder) {
|
||||
func (sd SoftDeleteQueryClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (SoftDeleteClause) MergeClause(*clause.Clause) {
|
||||
func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||
}})
|
||||
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
|
||||
}
|
||||
|
||||
type SoftDeleteDeleteClause struct {
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.String() == "" {
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}})
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}})
|
||||
|
||||
if stmt.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
@ -64,7 +88,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
@ -74,11 +98,6 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := stmt.Clauses["WHERE"]; !ok {
|
||||
stmt.DB.AddError(ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
stmt.Build("UPDATE", "SET", "WHERE")
|
||||
}
|
||||
|
29
statement.go
29
statement.go
@ -29,7 +29,7 @@ type Statement struct {
|
||||
Distinct bool
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
Joins map[string][]interface{}
|
||||
Joins []join
|
||||
Preloads map[string][]interface{}
|
||||
Settings sync.Map
|
||||
ConnPool ConnPool
|
||||
@ -44,6 +44,11 @@ type Statement struct {
|
||||
assigns []interface{}
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Conds []interface{}
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
type StatementModifier interface {
|
||||
ModifyStatement(*Statement)
|
||||
@ -156,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
stmt.Vars = append(stmt.Vars, v.Value)
|
||||
case clause.Column, clause.Table:
|
||||
stmt.QuoteTo(writer, v)
|
||||
case Valuer:
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
case clause.Expr:
|
||||
var varStr strings.Builder
|
||||
var sql = v.SQL
|
||||
@ -309,10 +316,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
||||
for _, field := range s.Fields {
|
||||
if field.Readable {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
if field.DBName == "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||
} else {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -322,10 +329,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
||||
for _, field := range s.Fields {
|
||||
if field.Readable {
|
||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
||||
if field.DBName == "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||
} else {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -374,7 +381,6 @@ func (stmt *Statement) Build(clauses ...string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO handle named vars
|
||||
}
|
||||
|
||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
@ -392,6 +398,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
|
||||
func (stmt *Statement) clone() *Statement {
|
||||
newStmt := &Statement{
|
||||
TableExpr: stmt.TableExpr,
|
||||
Table: stmt.Table,
|
||||
Model: stmt.Model,
|
||||
Dest: stmt.Dest,
|
||||
@ -400,7 +407,6 @@ func (stmt *Statement) clone() *Statement {
|
||||
Distinct: stmt.Distinct,
|
||||
Selects: stmt.Selects,
|
||||
Omits: stmt.Omits,
|
||||
Joins: map[string][]interface{}{},
|
||||
Preloads: map[string][]interface{}{},
|
||||
ConnPool: stmt.ConnPool,
|
||||
Schema: stmt.Schema,
|
||||
@ -416,8 +422,9 @@ func (stmt *Statement) clone() *Statement {
|
||||
newStmt.Preloads[k] = p
|
||||
}
|
||||
|
||||
for k, j := range stmt.Joins {
|
||||
newStmt.Joins[k] = j
|
||||
if len(stmt.Joins) > 0 {
|
||||
newStmt.Joins = make([]join, len(stmt.Joins))
|
||||
copy(newStmt.Joins, stmt.Joins)
|
||||
}
|
||||
|
||||
stmt.Settings.Range(func(k, v interface{}) bool {
|
||||
|
@ -21,6 +21,23 @@ func TestHasManyAssociation(t *testing.T) {
|
||||
DB.Model(&user2).Association("Pets").Find(&user2.Pets)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
var pets []Pet
|
||||
DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets)
|
||||
|
||||
if len(pets) != 1 {
|
||||
t.Fatalf("should only find one pets, but got %v", len(pets))
|
||||
}
|
||||
|
||||
CheckPet(t, pets[0], *user.Pets[0])
|
||||
|
||||
if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 {
|
||||
t.Fatalf("should only find one pets, but got %v", count)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 {
|
||||
t.Fatalf("should only find no pet with invalid conditions, but got %v", count)
|
||||
}
|
||||
|
||||
// Count
|
||||
AssertAssociationCount(t, user, "Pets", 2, "")
|
||||
|
||||
@ -40,13 +57,13 @@ func TestHasManyAssociation(t *testing.T) {
|
||||
|
||||
AssertAssociationCount(t, user, "Pets", 3, "AfterAppend")
|
||||
|
||||
var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}}
|
||||
var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}}
|
||||
|
||||
if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil {
|
||||
if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil {
|
||||
t.Fatalf("Error happened when append pet, got %v", err)
|
||||
}
|
||||
|
||||
for _, pet := range pets {
|
||||
for _, pet := range pets2 {
|
||||
var pet = pet
|
||||
if pet.ID == 0 {
|
||||
t.Fatalf("Pet's ID should be created")
|
||||
|
@ -67,4 +67,9 @@ func TestCount(t *testing.T) {
|
||||
if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
var count4 int64
|
||||
if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
|
||||
t.Errorf("count with join, got error: %v, count %v", err, count)
|
||||
}
|
||||
}
|
||||
|
@ -39,6 +39,45 @@ func TestCreate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFromMap(t *testing.T) {
|
||||
if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
var result1 User
|
||||
if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
datas := []map[string]interface{}{
|
||||
{"Name": "create_from_map_2", "Age": 19},
|
||||
{"name": "create_from_map_3", "Age": 20},
|
||||
}
|
||||
|
||||
if err := DB.Model(&User{}).Create(datas).Error; err != nil {
|
||||
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||
}
|
||||
|
||||
var result2 User
|
||||
if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
var result3 User
|
||||
if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithAssociations(t *testing.T) {
|
||||
var user = *GetUser("create_with_associations", Config{
|
||||
Account: true,
|
||||
|
@ -10,7 +10,7 @@ func TestDefaultValue(t *testing.T) {
|
||||
type Harumph struct {
|
||||
gorm.Model
|
||||
Email string `gorm:"not null;index:,unique"`
|
||||
Name string `gorm:"not null;default:'foo'"`
|
||||
Name string `gorm:"not null;default:foo"`
|
||||
Name2 string `gorm:"size:233;not null;default:'foo'"`
|
||||
Name3 string `gorm:"size:233;not null;default:''"`
|
||||
Age int `gorm:"default:18"`
|
||||
|
@ -43,6 +43,18 @@ func TestDelete(t *testing.T) {
|
||||
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := DB.Delete(users[0]).Error; err != nil {
|
||||
t.Errorf("errors happened when delete: %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause {
|
||||
t.Errorf("errors happened when delete: %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Errorf("should returns record not found error, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteWithTable(t *testing.T) {
|
||||
@ -110,4 +122,8 @@ func TestBlockGlobalDelete(t *testing.T) {
|
||||
if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
|
||||
t.Errorf("should returns missing WHERE clause while deleting error")
|
||||
}
|
||||
|
||||
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil {
|
||||
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -160,9 +160,9 @@ func TestEmbeddedRelations(t *testing.T) {
|
||||
Advanced bool
|
||||
}
|
||||
|
||||
DB.Debug().Migrator().DropTable(&AdvancedUser{})
|
||||
DB.Migrator().DropTable(&AdvancedUser{})
|
||||
|
||||
if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
|
||||
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
|
||||
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
10
tests/go.mod
10
tests/go.mod
@ -6,11 +6,11 @@ require (
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/jinzhu/now v1.1.1
|
||||
github.com/lib/pq v1.6.0
|
||||
gorm.io/driver/mysql v0.3.1
|
||||
gorm.io/driver/postgres v0.2.6
|
||||
gorm.io/driver/sqlite v1.0.8
|
||||
gorm.io/driver/sqlserver v0.2.5
|
||||
gorm.io/gorm v0.2.19
|
||||
gorm.io/driver/mysql v1.0.0
|
||||
gorm.io/driver/postgres v1.0.0
|
||||
gorm.io/driver/sqlite v1.1.0
|
||||
gorm.io/driver/sqlserver v1.0.0
|
||||
gorm.io/gorm v1.9.19
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) {
|
||||
if db5.Error != nil {
|
||||
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
|
||||
}
|
||||
|
||||
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement
|
||||
|
||||
if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) {
|
||||
t.Errorf("joins should be ordered, but got %v", stmt.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinsWithSelect(t *testing.T) {
|
||||
|
@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmartMigrateColumn(t *testing.T) {
|
||||
type UserMigrateColumn struct {
|
||||
ID uint
|
||||
Name string
|
||||
Salary float64
|
||||
Birthday time.Time
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&UserMigrateColumn{})
|
||||
|
||||
DB.AutoMigrate(&UserMigrateColumn{})
|
||||
|
||||
type UserMigrateColumn2 struct {
|
||||
ID uint
|
||||
Name string `gorm:"size:128"`
|
||||
Salary float64 `gorm:"precision:2"`
|
||||
Birthday time.Time `gorm:"precision:2"`
|
||||
}
|
||||
|
||||
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
|
||||
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); length != 0 && length != 128 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type UserMigrateColumn3 struct {
|
||||
ID uint
|
||||
Name string `gorm:"size:256"`
|
||||
Salary float64 `gorm:"precision:3"`
|
||||
Birthday time.Time `gorm:"precision:3"`
|
||||
}
|
||||
|
||||
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
|
||||
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); length != 0 && length != 256 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v", precision)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMigrateWithComment(t *testing.T) {
|
||||
type UserWithComment struct {
|
||||
gorm.Model
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) {
|
||||
}
|
||||
|
||||
CheckUser(t, users2[0], users[0])
|
||||
|
||||
var users3 []User
|
||||
if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Table("accounts AS a").Select("a.*")
|
||||
}).Find(&users3, "id IN ?", userIDs).Error; err != nil {
|
||||
t.Errorf("failed to query, got error %v", err)
|
||||
}
|
||||
sort.Slice(users3, func(i, j int) bool {
|
||||
return users2[i].ID < users2[j].ID
|
||||
})
|
||||
|
||||
for i, u := range users3 {
|
||||
CheckUser(t, u, users[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreloadWithConds(t *testing.T) {
|
||||
|
@ -103,6 +103,24 @@ func TestFind(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryWithAssociation(t *testing.T) {
|
||||
user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3})
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create user: %v", err)
|
||||
}
|
||||
|
||||
user.CreatedAt = time.Time{}
|
||||
user.UpdatedAt = time.Time{}
|
||||
if err := DB.Where(&user).First(&User{}).Error; err != nil {
|
||||
t.Errorf("search with struct with association should returns no error, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where(user).First(&User{}).Error; err != nil {
|
||||
t.Errorf("search with struct with association should returns no error, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindInBatches(t *testing.T) {
|
||||
var users = []User{
|
||||
*GetUser("find_in_batches", Config{}),
|
||||
@ -328,6 +346,11 @@ func TestSelect(t *testing.T) {
|
||||
if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
|
||||
t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
|
||||
}
|
||||
|
||||
r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID)
|
||||
if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
|
||||
t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOmit(t *testing.T) {
|
||||
|
@ -1,16 +1,20 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -27,6 +31,7 @@ func TestScannerValuer(t *testing.T) {
|
||||
Male: sql.NullBool{Bool: true, Valid: true},
|
||||
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
|
||||
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}},
|
||||
Password: EncryptedData("pass1"),
|
||||
Bytes: []byte("byte"),
|
||||
Num: 18,
|
||||
@ -35,7 +40,9 @@ func TestScannerValuer(t *testing.T) {
|
||||
{"name1", "value1"},
|
||||
{"name2", "value2"},
|
||||
},
|
||||
Role: Role{Name: "admin"},
|
||||
Role: Role{Name: "admin"},
|
||||
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||
}
|
||||
|
||||
if err := DB.Create(&data).Error; err != nil {
|
||||
@ -44,10 +51,17 @@ func TestScannerValuer(t *testing.T) {
|
||||
|
||||
var result ScannerValuerStruct
|
||||
|
||||
if err := DB.Find(&result).Error; err != nil {
|
||||
if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
|
||||
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
||||
}
|
||||
|
||||
if result.ExampleStructPtr.Val != "value2" {
|
||||
t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val)
|
||||
}
|
||||
|
||||
if result.ExampleStruct.Val != "value1" {
|
||||
t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct)
|
||||
}
|
||||
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
||||
}
|
||||
|
||||
@ -58,9 +72,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
||||
}
|
||||
|
||||
data := ScannerValuerStruct{
|
||||
Name: sql.NullString{String: "name", Valid: true},
|
||||
Gender: &sql.NullString{String: "M", Valid: true},
|
||||
Age: sql.NullInt64{Int64: 18, Valid: true},
|
||||
Name: sql.NullString{String: "name", Valid: true},
|
||||
Gender: &sql.NullString{String: "M", Valid: true},
|
||||
Age: sql.NullInt64{Int64: 18, Valid: true},
|
||||
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||
}
|
||||
|
||||
var result ScannerValuerStruct
|
||||
@ -99,7 +115,9 @@ func TestInvalidValuer(t *testing.T) {
|
||||
}
|
||||
|
||||
data := ScannerValuerStruct{
|
||||
Password: EncryptedData("xpass1"),
|
||||
Password: EncryptedData("xpass1"),
|
||||
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||
}
|
||||
|
||||
if err := DB.Create(&data).Error; err == nil {
|
||||
@ -124,21 +142,24 @@ func TestInvalidValuer(t *testing.T) {
|
||||
|
||||
type ScannerValuerStruct struct {
|
||||
gorm.Model
|
||||
Name sql.NullString
|
||||
Gender *sql.NullString
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
Birthday sql.NullTime
|
||||
Password EncryptedData
|
||||
Bytes []byte
|
||||
Num Num
|
||||
Strings StringsSlice
|
||||
Structs StructsSlice
|
||||
Role Role
|
||||
UserID *sql.NullInt64
|
||||
User User
|
||||
EmptyTime EmptyTime
|
||||
Name sql.NullString
|
||||
Gender *sql.NullString
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
Birthday sql.NullTime
|
||||
Allergen NullString
|
||||
Password EncryptedData
|
||||
Bytes []byte
|
||||
Num Num
|
||||
Strings StringsSlice
|
||||
Structs StructsSlice
|
||||
Role Role
|
||||
UserID *sql.NullInt64
|
||||
User User
|
||||
EmptyTime EmptyTime
|
||||
ExampleStruct ExampleStruct
|
||||
ExampleStructPtr *ExampleStruct
|
||||
}
|
||||
|
||||
type EncryptedData []byte
|
||||
@ -203,8 +224,32 @@ func (l *StringsSlice) Scan(input interface{}) error {
|
||||
}
|
||||
|
||||
type ExampleStruct struct {
|
||||
Name string
|
||||
Value string
|
||||
Name string
|
||||
Val string
|
||||
}
|
||||
|
||||
func (ExampleStruct) GormDataType() string {
|
||||
return "bytes"
|
||||
}
|
||||
|
||||
func (s ExampleStruct) Value() (driver.Value, error) {
|
||||
if len(s.Name) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// for test, has no practical meaning
|
||||
s.Name = ""
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
func (s *ExampleStruct) Scan(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), s)
|
||||
case []byte:
|
||||
return json.Unmarshal(value, s)
|
||||
default:
|
||||
return errors.New("not supported")
|
||||
}
|
||||
}
|
||||
|
||||
type StructsSlice []ExampleStruct
|
||||
@ -260,3 +305,89 @@ func (t *EmptyTime) Scan(v interface{}) error {
|
||||
func (t EmptyTime) Value() (driver.Value, error) {
|
||||
return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil
|
||||
}
|
||||
|
||||
type NullString struct {
|
||||
sql.NullString
|
||||
}
|
||||
|
||||
type Point struct {
|
||||
X, Y int
|
||||
}
|
||||
|
||||
func (point Point) GormDataType() string {
|
||||
return "geo"
|
||||
}
|
||||
|
||||
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
||||
return clause.Expr{
|
||||
SQL: "ST_PointFromText(?)",
|
||||
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
|
||||
}
|
||||
}
|
||||
|
||||
func TestGORMValuer(t *testing.T) {
|
||||
type UserWithPoint struct {
|
||||
Name string
|
||||
Point Point
|
||||
}
|
||||
|
||||
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
|
||||
stmt := dryRunDB.Create(&UserWithPoint{
|
||||
Name: "jinzhu",
|
||||
Point: Point{X: 100, Y: 100},
|
||||
}).Statement
|
||||
|
||||
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
|
||||
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||
}
|
||||
|
||||
stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{
|
||||
"Name": "jinzhu",
|
||||
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
||||
}).Statement
|
||||
|
||||
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||
}
|
||||
|
||||
stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{
|
||||
"Name": "jinzhu",
|
||||
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
||||
}).Statement
|
||||
|
||||
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||
}
|
||||
|
||||
stmt = dryRunDB.Session(&gorm.Session{
|
||||
AllowGlobalUpdate: true,
|
||||
}).Model(&UserWithPoint{}).Updates(UserWithPoint{
|
||||
Name: "jinzhu",
|
||||
Point: Point{X: 100, Y: 100},
|
||||
}).Statement
|
||||
|
||||
if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) {
|
||||
t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||
}
|
||||
}
|
||||
|
@ -40,6 +40,17 @@ func TestTable(t *testing.T) {
|
||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||
}
|
||||
|
||||
r = dryDB.Create(&UserWithTable{}).Statement
|
||||
if DB.Dialector.Name() != "sqlite" {
|
||||
if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) {
|
||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||
}
|
||||
} else {
|
||||
if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) {
|
||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement
|
||||
if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
|
||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
cancelFunc()
|
||||
|
||||
user := *GetUser("cancel_transaction", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var result User
|
||||
tx.First(&result, user.ID)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Transaction should get error when using cancelled context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionWithBlock(t *testing.T) {
|
||||
assertPanic := func(f func()) {
|
||||
defer func() {
|
||||
|
@ -222,6 +222,10 @@ func TestBlockGlobalUpdate(t *testing.T) {
|
||||
if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
|
||||
t.Errorf("should returns missing WHERE clause while updating error, got err %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil {
|
||||
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithUpdate(t *testing.T) {
|
||||
@ -334,6 +338,15 @@ func TestSelectWithUpdateWithMap(t *testing.T) {
|
||||
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
|
||||
}
|
||||
|
||||
func TestWithUpdateWithInvalidMap(t *testing.T) {
|
||||
user := *GetUser("update_with_invalid_map", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) {
|
||||
t.Errorf("should returns error for unsupported updating data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOmitWithUpdate(t *testing.T) {
|
||||
user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Create(&user)
|
||||
@ -545,3 +558,31 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) {
|
||||
t.Errorf("element's ignored field should not be updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromSubQuery(t *testing.T) {
|
||||
user := *GetUser("update_from_sub_query", Config{Company: true})
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Errorf("failed to create user, got error: %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil {
|
||||
t.Errorf("failed to update with sub query, got error %v", err)
|
||||
}
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
|
||||
if result.Name != user.Company.Name {
|
||||
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
|
||||
}
|
||||
|
||||
DB.Model(&user.Company).Update("Name", "new company name")
|
||||
if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil {
|
||||
t.Errorf("failed to update with sub query, got error %v", err)
|
||||
}
|
||||
|
||||
DB.First(&result, user.ID)
|
||||
if result.Name != "new company name" {
|
||||
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
|
||||
}
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
||||
}
|
||||
} else {
|
||||
name := reflect.ValueOf(got).Type().Elem().Name()
|
||||
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
|
||||
t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -29,8 +29,8 @@ func FileWithLineNum() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func IsChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*'
|
||||
func IsValidDBNameChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
|
||||
}
|
||||
|
||||
func CheckTruth(val interface{}) bool {
|
||||
|
14
utils/utils_test.go
Normal file
14
utils/utils_test.go
Normal file
@ -0,0 +1,14 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsValidDBNameChar(t *testing.T) {
|
||||
for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} {
|
||||
if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 {
|
||||
t.Fatalf("failed to parse db name %v", db)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user