Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Steve Fan 2020-08-29 17:30:28 +08:00
commit 68640733b6
47 changed files with 1100 additions and 385 deletions

View File

@ -13,7 +13,7 @@ jobs:
sqlite:
strategy:
matrix:
go: ['1.14', '1.13']
go: ['1.15', '1.14', '1.13']
platform: [ubuntu-latest, macos-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
@ -38,7 +38,7 @@ jobs:
sqlite_windows:
strategy:
matrix:
go: ['1.14', '1.13']
go: ['1.15', '1.14', '1.13']
platform: [windows-latest]
runs-on: ${{ matrix.platform }}
@ -64,7 +64,7 @@ jobs:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
go: ['1.14', '1.13']
go: ['1.15', '1.14', '1.13']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -108,7 +108,7 @@ jobs:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
go: ['1.14', '1.13']
go: ['1.15', '1.14', '1.13']
platform: [ubuntu-latest] # can not run in macOS and widnowsOS
runs-on: ${{ matrix.platform }}
@ -150,7 +150,7 @@ jobs:
sqlserver:
strategy:
matrix:
go: ['1.14', '1.13']
go: ['1.15', '1.14', '1.13']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}

View File

@ -19,11 +19,11 @@ The fantastic ORM library for Golang, aims to be developer friendly.
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
* Context, Prepared Statment Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
* Composite Primary Key
* Auto Migrations
* Logger
* Extendable, write Plugins based on GORM callbacks
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
* Every feature comes with tests
* Developer Friendly
@ -40,4 +40,3 @@ The fantastic ORM library for Golang, aims to be developer friendly.
© Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)

View File

@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association {
func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
tx = association.DB.Model(out)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
association.Error = tx.Find(out, conds...).Error
association.Error = association.buildCondition().Find(out, conds...).Error
}
return association.Error
}
@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error {
association.Error = association.Replace(values...)
}
default:
association.saveAssociation(false, values...)
association.saveAssociation( /*clear*/ false, values...)
}
}
@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error {
func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil {
// save associations
association.saveAssociation(true, values...)
association.saveAssociation( /*clear*/ true, values...)
// set old associations's foreign key to null
reflectValue := association.DB.Statement.ReflectValue
@ -112,7 +88,7 @@ func (association *Association) Replace(values ...interface{}) error {
updateMap[ref.ForeignKey.DBName] = nil
}
association.DB.UpdateColumns(updateMap)
association.Error = association.DB.UpdateColumns(updateMap).Error
}
case schema.HasOne, schema.HasMany:
var (
@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error {
var (
primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
)
for _, ref := range rel.References {
@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error {
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
}
if association.Error == nil {
// clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) {
@ -328,33 +305,8 @@ func (association *Association) Clear() error {
func (association *Association) Count() (count int64) {
if association.Error == nil {
var (
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: conds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: conds})
}
association.Error = tx.Count(&count).Error
association.Error = association.buildCondition().Count(&count).Error
}
return
}
@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(values) != reflectValue.Len() {
// clear old data
if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
}
case reflect.Struct:
// clear old data
if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
}
}
func (association *Association) buildCondition() *DB {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
return tx
}

View File

@ -48,14 +48,19 @@ func SaveBeforeAssociations(db *gorm.DB) {
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
}
} else {
break
}
}
@ -112,22 +117,24 @@ func SaveAfterAssociations(db *gorm.DB) {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
db.AddError(ref.ForeignKey.Set(rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
}
elems = reflect.Append(elems, rv)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
db.AddError(ref.ForeignKey.Set(rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
}
}
elems = reflect.Append(elems, rv)
}
}
}
@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)

View File

@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok {
if i, ok := value.(BeforeCreateInterface); ok {
called = true
db.AddError(i.BeforeCreate(tx))
}
@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i))
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID--
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID++
}
}
@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) {
for rows.Next() {
BEGIN:
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
break
}
for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue)
@ -189,14 +203,14 @@ func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}
@ -211,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch value := stmt.Dest.(type) {
case map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, value)
case *map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, *value)
case []map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
case *[]map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
default:
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
@ -295,6 +313,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}

View File

@ -11,7 +11,7 @@ import (
func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx))
return true
}
@ -41,7 +41,7 @@ func Delete(db *gorm.DB) {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
@ -51,15 +51,15 @@ func Delete(db *gorm.DB) {
}
}
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("DELETE", "FROM", "WHERE")
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
@ -75,7 +75,7 @@ func Delete(db *gorm.DB) {
func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx))
return true
}

View File

@ -20,12 +20,18 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
for _, k := range keys {
value := mapValue[k]
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: k})
if len(values.Values) == 0 {
values.Values = [][]interface{}{{}}
}
values.Values[0] = append(values.Values[0], value)
}
}
@ -42,8 +48,10 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
for idx, mapValue := range mapValues {
for k, v := range mapValue {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
}
if _, ok := result[k]; !ok {
@ -61,11 +69,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
sort.Strings(columns)
values.Values = make([][]interface{}, len(mapValues))
values.Columns = make([]clause.Column, len(columns))
for idx, column := range columns {
values.Columns[idx] = clause.Column{Name: column}
for i, v := range result[column] {
if i == 0 {
if len(values.Values[i]) == 0 {
values.Values[i] = make([]interface{}, len(columns))
}
values.Values[i][idx] = v
}
}

39
callbacks/interfaces.go Normal file
View File

@ -0,0 +1,39 @@
package callbacks
import "gorm.io/gorm"
type BeforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}
type AfterCreateInterface interface {
AfterCreate(*gorm.DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*gorm.DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*gorm.DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type AfterSaveInterface interface {
AfterSave(*gorm.DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*gorm.DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}

View File

@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
// convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields))
@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues)
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
}
}
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
fieldValues := make([]interface{}, len(relForeignFields))

View File

@ -96,7 +96,7 @@ func BuildQuerySQL(db *gorm.DB) {
// inline joins
if len(db.Statement.Joins) != 0 {
if len(db.Statement.Selects) == 0 {
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) {
}
joins := []clause.Join{}
for name, conds := range db.Statement.Joins {
for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil {
joins = append(joins, clause.Join{
Expression: clause.Expr{SQL: name, Vars: conds},
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
})
} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok {
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
tableAliasName := relation.Name
for _, s := range relation.FieldSchema.DBNames {
@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) {
})
} else {
joins = append(joins, clause.Join{
Expression: clause.Expr{SQL: name, Vars: conds},
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
})
}
}
@ -214,7 +214,7 @@ func Preload(db *gorm.DB) {
func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.AfterFindInterface); ok {
if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx))
return true
}

View File

@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
if i, ok := value.(BeforeUpdateInterface); ok {
called = true
db.AddError(i.BeforeUpdate(tx))
}
@ -69,7 +69,7 @@ func Update(db *gorm.DB) {
db.Statement.Build("UPDATE", "SET", "WHERE")
}
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
sort.Strings(keys)
for _, k := range keys {
kv := value[k]
if _, ok := kv.(*gorm.DB); ok {
kv = []interface{}{kv}
}
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
@ -189,12 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
}
}
if !stmt.UpdatingColumn && stmt.Schema != nil {
for _, field := range stmt.Schema.FieldsByDBName {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
now := stmt.DB.NowFunc()
@ -217,7 +223,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, field := range stmt.Schema.FieldsByDBName {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
value, isZero := field.ValueOf(updatingValue)
@ -247,6 +254,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}

View File

@ -69,7 +69,7 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
if len(args) > 0 {
tx = tx.Select(args[0], args[1:]...)
}
return tx
return
}
// Select specify fields that you want when querying, creating, updating
@ -91,8 +91,9 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return
}
}
delete(tx.Statement.Clauses, "SELECT")
case string:
fields := strings.FieldsFunc(v, utils.IsChar)
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar)
// normal field names
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
@ -112,6 +113,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return
}
}
delete(tx.Statement.Clauses, "SELECT")
} else {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
@ -130,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
} else {
tx.Statement.Omits = columns
}
@ -169,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Joins == nil {
tx.Statement.Joins = map[string][]interface{}{}
}
tx.Statement.Joins[query] = args
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
return
}
@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance()
fields := strings.FieldsFunc(name, utils.IsChar)
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
tx.Statement.AddClause(clause.GroupBy{
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
})

View File

@ -19,6 +19,8 @@ var (
ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required")
// ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver
ErrUnsupportedDriver = errors.New("unsupported driver")
// ErrRegistered registered

View File

@ -8,6 +8,7 @@ import (
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
@ -132,19 +133,46 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
return
}
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
}
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
exprs := tx.Statement.BuildCondition(value)
tx.assignInterfacesToValue(exprs)
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
}
}
}
}
}
default:
} else if len(values) > 0 {
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
tx.assignInterfacesToValue(exprs)
return
}
}
}
@ -154,22 +182,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs)
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.attrs...)
}
tx.Error = nil
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return
}
@ -180,20 +206,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs)
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.attrs...)
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return tx.Create(dest)
@ -265,11 +289,14 @@ func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest
defer func() {
tx.Statement.Model = nil
}()
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
defer tx.Statement.AddClause(clause.Select{})
defer delete(tx.Statement.Clauses, "SELECT")
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
expr := clause.Expr{SQL: "count(1)"}
@ -335,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx.AddError(ErrModelValueRequired)
}
fields := strings.FieldsFunc(column, utils.IsChar)
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
@ -420,7 +447,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
// Commit commit a transaction
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit())
} else {
db.AddError(ErrInvalidTransaction)
@ -431,7 +458,9 @@ func (db *DB) Commit() *DB {
// Rollback rollback a transaction
func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
db.AddError(committer.Rollback())
if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback())
}
} else {
db.AddError(ErrInvalidTransaction)
}

View File

@ -32,6 +32,8 @@ type Config struct {
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
@ -61,6 +63,7 @@ type Session struct {
PrepareStmt bool
WithConditions bool
SkipDefaultTransaction bool
AllowGlobalUpdate bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
@ -154,6 +157,10 @@ func (db *DB) Session(config *Session) *DB {
tx.Config.SkipDefaultTransaction = true
}
if config.AllowGlobalUpdate {
txConfig.AllowGlobalUpdate = true
}
if config.Context != nil {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx

View File

@ -53,38 +53,7 @@ type TxCommitter interface {
Rollback() error
}
type BeforeCreateInterface interface {
BeforeCreate(*DB) error
}
type AfterCreateInterface interface {
AfterCreate(*DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*DB) error
}
type AfterSaveInterface interface {
AfterSave(*DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*DB) error
}
type AfterFindInterface interface {
AfterFind(*DB) error
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}

View File

@ -42,6 +42,7 @@ type Migrator interface {
AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)

View File

@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
"gorm.io/gorm"
@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
// AutoMigrate
func (m Migrator) AutoMigrate(values ...interface{}) error {
// TODO smart migrate data type
for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{})
if !tx.Migrator().HasTable(value) {
@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
for _, field := range stmt.Schema.FieldsByDBName {
if !tx.Migrator().HasColumn(value, field.DBName) {
var foundColumn *sql.ColumnType
for _, columnType := range columnTypes {
if columnType.Name() == field.DBName {
foundColumn = columnType
break
}
}
if foundColumn == nil {
// not found, add column
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
return err
}
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
}
}
@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
return nil
}); err != nil {
fmt.Println(err)
return err
}
}
@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
})
}
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error {
// found, smart migrate
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
realDataType := strings.ToLower(columnType.DatabaseTypeName())
alterColumn := false
// check size
if length, _ := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
alterColumn = true
}
}
}
// check precision
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) {
alterColumn = true
}
}
// check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & database is nullable
if !field.PrimaryKey && nullable {
alterColumn = true
}
}
if alterColumn {
return m.DB.Migrator().AlterColumn(value, field.Name)
}
return nil
}
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()

View File

@ -55,6 +55,7 @@ type Field struct {
Comment string
Size int
Precision int
Scale int
FieldType reflect.Type
IndirectFieldType reflect.Type
StructField reflect.StructField
@ -88,53 +89,44 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
fieldValue := reflect.New(field.IndirectFieldType)
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...)
}
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...)
}
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...)
}
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...)
}
// if field is valuer, used its value or first fields as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
var overrideFieldValue bool
if v, err := valuer.Value(); v != nil && err == nil {
overrideFieldValue = true
if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
fieldValue = reflect.ValueOf(v)
}
if field.IndirectFieldType.Kind() == reflect.Struct {
for i := 0; i < field.IndirectFieldType.NumField(); i++ {
if !overrideFieldValue {
newFieldType := field.IndirectFieldType.Field(i).Type
var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v)
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
for i := 0; i < rv.Type().NumField(); i++ {
newFieldType := rv.Type().Field(i).Type
for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem()
}
fieldValue = reflect.New(newFieldType)
overrideFieldValue = true
}
// copy tag settings from valuer
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
if rv.Type() != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue)
}
if fieldValue.IsValid() {
return
}
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
}
}
getRealFieldValue(fieldValue)
}
}
@ -169,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Precision, _ = strconv.Atoi(p)
}
if s, ok := field.TagSettings["SCALE"]; ok {
field.Scale, _ = strconv.Atoi(s)
}
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true
}
@ -311,15 +307,19 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) {
if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) {
if reflect.Indirect(fieldValue).Kind() == reflect.Struct {
var err error
field.Creatable = false
field.Updatable = false
field.Readable = false
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil {
schema.err = err
}
for _, ef := range field.EmbeddedSchema.Fields {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
@ -341,17 +341,20 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.PrimaryKey = true
} else {
ef.PrimaryKey = false
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
ef.AutoIncrement = false
}
if ef.DefaultValue == "" {
ef.HasDefaultValue = false
}
}
for k, v := range field.TagSettings {
ef.TagSettings[k] = v
}
}
field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...)
field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...)
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...)
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...)
} else {
schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
}
@ -475,16 +478,16 @@ func (field *Field) setupValuerAndSetter() {
}
}
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = setter(value, v)
}
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
err = setter(value, reflectV.Elem().Interface())
}
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = setter(value, v)
}
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
@ -731,40 +734,10 @@ func (field *Field) setupValuerAndSetter() {
return nil
}
default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if valuer, ok := v.(driver.Valuer); ok {
if valuer == nil || reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
v, _ = valuer.Value()
}
}
if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
@ -778,10 +751,38 @@ func (field *Field) setupValuerAndSetter() {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = fieldValue.Interface().(sql.Scanner).Scan(v)
}
return
}
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
field.Set = func(value reflect.Value, v interface{}) (err error) {
return fallbackSetter(value, v, field.Set)

View File

@ -1,23 +1,25 @@
package schema
import "gorm.io/gorm/clause"
import (
"gorm.io/gorm/clause"
)
type GormDataTypeInterface interface {
GormDataType() string
}
type CreateClausesInterface interface {
CreateClauses() []clause.Interface
CreateClauses(*Field) []clause.Interface
}
type QueryClausesInterface interface {
QueryClauses() []clause.Interface
QueryClauses(*Field) []clause.Interface
}
type UpdateClausesInterface interface {
UpdateClauses() []clause.Interface
UpdateClauses(*Field) []clause.Interface
}
type DeleteClausesInterface interface {
DeleteClauses() []clause.Interface
DeleteClauses(*Field) []clause.Interface
}

View File

@ -39,3 +39,24 @@ type AdvancedDataTypeUser struct {
Active mybool
Admin *mybool
}
type BaseModel struct {
ID uint
CreatedAt time.Time
CreatedBy *int
Created *VersionUser `gorm:"foreignKey:CreatedBy"`
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
}
type VersionModel struct {
BaseModel
Version int
}
type VersionUser struct {
VersionModel
Name string
Age uint
Birthday *time.Time
}

View File

@ -5,7 +5,6 @@ import (
"reflect"
"regexp"
"strings"
"sync"
"github.com/jinzhu/inflection"
"gorm.io/gorm/clause"
@ -67,16 +66,14 @@ func (schema *Schema) parseRelation(field *Field) {
}
)
cacheStore := schema.cacheStore
if field.OwnerSchema != nil {
if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil {
schema.err = err
return
}
} else {
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err
return
}
cacheStore = field.OwnerSchema.cacheStore
}
if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = err
return
}
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
@ -85,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) {
schema.buildMany2ManyRelation(relation, field, many2many)
} else {
switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice:
case reflect.Struct:
schema.guessRelation(relation, field, guessBelongs)
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
@ -228,7 +227,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type,
Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"),
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
@ -251,7 +250,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"),
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
@ -327,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
type guessLevel int
const (
guessHas guessLevel = iota
guessEmbeddedHas
guessBelongs
guessBelongs guessLevel = iota
guessEmbeddedBelongs
guessHas
guessEmbeddedHas
)
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
@ -339,34 +338,36 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
primarySchema, foreignSchema = schema, relation.FieldSchema
)
reguessOrErr := func(err string, args ...interface{}) {
reguessOrErr := func() {
switch gl {
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
case guessEmbeddedHas:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
case guessEmbeddedBelongs:
schema.guessRelation(relation, field, guessHas)
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf(err, args...)
schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
}
}
switch gl {
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
}
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
reguessOrErr()
return
}
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr()
return
}
}
@ -376,7 +377,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
if f := foreignSchema.LookUpField(foreignKey); f != nil {
foreignFields = append(foreignFields, f)
} else {
reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys)
reguessOrErr()
return
}
}
@ -395,7 +396,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
}
if len(foreignFields) == 0 {
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
reguessOrErr()
return
} else if len(relation.primaryKeys) > 0 {
for idx, primaryKey := range relation.primaryKeys {
@ -403,11 +404,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
if len(primaryFields) < idx+1 {
primaryFields = append(primaryFields, f)
} else if f != primaryFields[idx] {
reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
reguessOrErr()
return
}
} else {
reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
reguessOrErr()
return
}
}
@ -417,7 +418,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
} else {
reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name)
reguessOrErr()
return
}
}

View File

@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) {
})
}
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatedBy *int32
Creator *User `gorm:"foreignKey:CreatedBy;references:ID"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct {
gorm.Model

View File

@ -212,11 +212,29 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
// parse relations for unidentified fields
for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
}
}
fieldValue := reflect.New(field.IndirectFieldType)
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
}
}

View File

@ -49,10 +49,15 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
}
}
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
parsedField, ok := s.FieldsByDBName[f.DBName]
if !ok {
parsedField, ok = s.FieldsByName[f.Name]
}
if !ok {
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
} else {
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings")
if f.DBName != "" {
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
@ -62,7 +67,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
for _, name := range []string{f.DBName, f.Name} {
if name != "" {
if field := s.LookUpField(name); field == nil || parsedField != field {
if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
}
}

View File

@ -4,6 +4,7 @@ import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
@ -160,3 +161,65 @@ func TestCustomizeTableName(t *testing.T) {
t.Errorf("Failed to customize table with TableName method")
}
}
func TestNestedModel(t *testing.T) {
versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse nested user, got error %v", err)
}
fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true},
{Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64},
{Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64},
}
for _, f := range fields {
checkSchemaField(t, versionUser, &f, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
})
}
}
func TestEmbeddedStruct(t *testing.T) {
type CorpBase struct {
gorm.Model
OwnerID string
}
type Company struct {
ID int
OwnerID int
Name string
}
type Corp struct {
CorpBase
Base Company `gorm:"embedded;embeddedPrefix:company_"`
}
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
}
fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
}
for _, f := range fields {
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
})
}
}

View File

@ -9,6 +9,8 @@ import (
"gorm.io/gorm/utils"
)
var embeddedCacheKey = "embedded_cache_store"
func ParseTagSetting(str string, sep string) map[string]string {
settings := map[string]string{}
names := strings.Split(str, sep)
@ -49,8 +51,11 @@ func toColumns(val string) (results []string) {
return
}
func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag {
return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
for _, name := range names {
tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
}
return tag
}
// GetRelationsValues get relations's values from a reflect value

View File

@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) {
return n.Time, nil
}
func (DeletedAt) QueryClauses() []clause.Interface {
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{
clause.Where{Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"},
Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
Value: nil,
},
}},
}
}
func (DeletedAt) DeleteClauses() []clause.Interface {
return []clause.Interface{SoftDeleteClause{}}
type SoftDeleteQueryClause struct {
Field *schema.Field
}
type SoftDeleteClause struct {
}
func (SoftDeleteClause) Name() string {
func (sd SoftDeleteQueryClause) Name() string {
return ""
}
func (SoftDeleteClause) Build(clause.Builder) {
func (sd SoftDeleteQueryClause) Build(clause.Builder) {
}
func (SoftDeleteClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
}
func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
}})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
}
}
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
}
type SoftDeleteDeleteClause struct {
Field *schema.Field
}
func (sd SoftDeleteDeleteClause) Name() string {
return ""
}
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
}
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.String() == "" {
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}})
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}})
if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
@ -64,7 +88,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
if stmt.Dest != stmt.Model && stmt.Model != nil {
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
@ -74,11 +98,6 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
}
}
if _, ok := stmt.Clauses["WHERE"]; !ok {
stmt.DB.AddError(ErrMissingWhereClause)
return
}
stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build("UPDATE", "SET", "WHERE")
}

View File

@ -29,7 +29,7 @@ type Statement struct {
Distinct bool
Selects []string // selected columns
Omits []string // omit columns
Joins map[string][]interface{}
Joins []join
Preloads map[string][]interface{}
Settings sync.Map
ConnPool ConnPool
@ -44,6 +44,11 @@ type Statement struct {
assigns []interface{}
}
type join struct {
Name string
Conds []interface{}
}
// StatementModifier statement modifier interface
type StatementModifier interface {
ModifyStatement(*Statement)
@ -156,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
stmt.Vars = append(stmt.Vars, v.Value)
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case Valuer:
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
case clause.Expr:
var varStr strings.Builder
var sql = v.SQL
@ -309,10 +316,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
for _, field := range s.Fields {
if field.Readable {
if v, isZero := field.ValueOf(reflectValue); !isZero {
if field.DBName == "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
} else {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
}
}
}
@ -322,10 +329,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
for _, field := range s.Fields {
if field.Readable {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
if field.DBName == "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
} else {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
}
}
}
@ -374,7 +381,6 @@ func (stmt *Statement) Build(clauses ...string) {
}
}
}
// TODO handle named vars
}
func (stmt *Statement) Parse(value interface{}) (err error) {
@ -392,6 +398,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
func (stmt *Statement) clone() *Statement {
newStmt := &Statement{
TableExpr: stmt.TableExpr,
Table: stmt.Table,
Model: stmt.Model,
Dest: stmt.Dest,
@ -400,7 +407,6 @@ func (stmt *Statement) clone() *Statement {
Distinct: stmt.Distinct,
Selects: stmt.Selects,
Omits: stmt.Omits,
Joins: map[string][]interface{}{},
Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool,
Schema: stmt.Schema,
@ -416,8 +422,9 @@ func (stmt *Statement) clone() *Statement {
newStmt.Preloads[k] = p
}
for k, j := range stmt.Joins {
newStmt.Joins[k] = j
if len(stmt.Joins) > 0 {
newStmt.Joins = make([]join, len(stmt.Joins))
copy(newStmt.Joins, stmt.Joins)
}
stmt.Settings.Range(func(k, v interface{}) bool {

View File

@ -21,6 +21,23 @@ func TestHasManyAssociation(t *testing.T) {
DB.Model(&user2).Association("Pets").Find(&user2.Pets)
CheckUser(t, user2, user)
var pets []Pet
DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets)
if len(pets) != 1 {
t.Fatalf("should only find one pets, but got %v", len(pets))
}
CheckPet(t, pets[0], *user.Pets[0])
if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 {
t.Fatalf("should only find one pets, but got %v", count)
}
if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 {
t.Fatalf("should only find no pet with invalid conditions, but got %v", count)
}
// Count
AssertAssociationCount(t, user, "Pets", 2, "")
@ -40,13 +57,13 @@ func TestHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Pets", 3, "AfterAppend")
var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}}
var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}}
if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil {
if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil {
t.Fatalf("Error happened when append pet, got %v", err)
}
for _, pet := range pets {
for _, pet := range pets2 {
var pet = pet
if pet.ID == 0 {
t.Fatalf("Pet's ID should be created")

View File

@ -67,4 +67,9 @@ func TestCount(t *testing.T) {
if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String())
}
var count4 int64
if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
t.Errorf("count with join, got error: %v, count %v", err, count)
}
}

View File

@ -39,6 +39,45 @@ func TestCreate(t *testing.T) {
}
}
func TestCreateFromMap(t *testing.T) {
if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
var result User
if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 {
t.Fatalf("failed to create from map, got error %v", err)
}
if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
var result1 User
if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 {
t.Fatalf("failed to create from map, got error %v", err)
}
datas := []map[string]interface{}{
{"Name": "create_from_map_2", "Age": 19},
{"name": "create_from_map_3", "Age": 20},
}
if err := DB.Model(&User{}).Create(datas).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}
var result2 User
if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
var result3 User
if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
}
func TestCreateWithAssociations(t *testing.T) {
var user = *GetUser("create_with_associations", Config{
Account: true,

View File

@ -10,7 +10,7 @@ func TestDefaultValue(t *testing.T) {
type Harumph struct {
gorm.Model
Email string `gorm:"not null;index:,unique"`
Name string `gorm:"not null;default:'foo'"`
Name string `gorm:"not null;default:foo"`
Name2 string `gorm:"size:233;not null;default:'foo'"`
Name3 string `gorm:"size:233;not null;default:''"`
Age int `gorm:"default:18"`

View File

@ -43,6 +43,18 @@ func TestDelete(t *testing.T) {
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
}
}
if err := DB.Delete(users[0]).Error; err != nil {
t.Errorf("errors happened when delete: %v", err)
}
if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause {
t.Errorf("errors happened when delete: %v", err)
}
if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
t.Errorf("should returns record not found error, but got %v", err)
}
}
func TestDeleteWithTable(t *testing.T) {
@ -110,4 +122,8 @@ func TestBlockGlobalDelete(t *testing.T) {
if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
t.Errorf("should returns missing WHERE clause while deleting error")
}
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil {
t.Errorf("should returns no error while enable global update, but got err %v", err)
}
}

View File

@ -160,9 +160,9 @@ func TestEmbeddedRelations(t *testing.T) {
Advanced bool
}
DB.Debug().Migrator().DropTable(&AdvancedUser{})
DB.Migrator().DropTable(&AdvancedUser{})
if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
}
}

View File

@ -6,11 +6,11 @@ require (
github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
gorm.io/driver/mysql v0.3.1
gorm.io/driver/postgres v0.2.6
gorm.io/driver/sqlite v1.0.8
gorm.io/driver/sqlserver v0.2.5
gorm.io/gorm v0.2.19
gorm.io/driver/mysql v1.0.0
gorm.io/driver/postgres v1.0.0
gorm.io/driver/sqlite v1.1.0
gorm.io/driver/sqlserver v1.0.0
gorm.io/gorm v1.9.19
)
replace gorm.io/gorm => ../

View File

@ -1,6 +1,7 @@
package tests_test
import (
"regexp"
"sort"
"testing"
@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) {
if db5.Error != nil {
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
}
dryDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement
if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) {
t.Errorf("joins should be ordered, but got %v", stmt.SQL.String())
}
}
func TestJoinsWithSelect(t *testing.T) {

View File

@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) {
}
}
func TestSmartMigrateColumn(t *testing.T) {
type UserMigrateColumn struct {
ID uint
Name string
Salary float64
Birthday time.Time
}
DB.Migrator().DropTable(&UserMigrateColumn{})
DB.AutoMigrate(&UserMigrateColumn{})
type UserMigrateColumn2 struct {
ID uint
Name string `gorm:"size:128"`
Salary float64 `gorm:"precision:2"`
Birthday time.Time `gorm:"precision:2"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); length != 0 && length != 128 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}
}
type UserMigrateColumn3 struct {
ID uint
Name string `gorm:"size:256"`
Salary float64 `gorm:"precision:3"`
Birthday time.Time `gorm:"precision:3"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); length != 0 && length != 256 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
t.Fatalf("salary's precision should be 2, but got %v", precision)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}
}
}
func TestMigrateWithComment(t *testing.T) {
type UserWithComment struct {
gorm.Model

View File

@ -5,6 +5,7 @@ import (
"strconv"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) {
}
CheckUser(t, users2[0], users[0])
var users3 []User
if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB {
return tx.Table("accounts AS a").Select("a.*")
}).Find(&users3, "id IN ?", userIDs).Error; err != nil {
t.Errorf("failed to query, got error %v", err)
}
sort.Slice(users3, func(i, j int) bool {
return users2[i].ID < users2[j].ID
})
for i, u := range users3 {
CheckUser(t, u, users[i])
}
}
func TestNestedPreloadWithConds(t *testing.T) {

View File

@ -103,6 +103,24 @@ func TestFind(t *testing.T) {
})
}
func TestQueryWithAssociation(t *testing.T) {
user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3})
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create user: %v", err)
}
user.CreatedAt = time.Time{}
user.UpdatedAt = time.Time{}
if err := DB.Where(&user).First(&User{}).Error; err != nil {
t.Errorf("search with struct with association should returns no error, but got %v", err)
}
if err := DB.Where(user).First(&User{}).Error; err != nil {
t.Errorf("search with struct with association should returns no error, but got %v", err)
}
}
func TestFindInBatches(t *testing.T) {
var users = []User{
*GetUser("find_in_batches", Config{}),
@ -328,6 +346,11 @@ func TestSelect(t *testing.T) {
if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
}
r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID)
if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
}
}
func TestOmit(t *testing.T) {

View File

@ -1,16 +1,20 @@
package tests_test
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
@ -27,6 +31,7 @@ func TestScannerValuer(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}},
Password: EncryptedData("pass1"),
Bytes: []byte("byte"),
Num: 18,
@ -35,7 +40,9 @@ func TestScannerValuer(t *testing.T) {
{"name1", "value1"},
{"name2", "value2"},
},
Role: Role{Name: "admin"},
Role: Role{Name: "admin"},
ExampleStruct: ExampleStruct{"name", "value1"},
ExampleStructPtr: &ExampleStruct{"name", "value2"},
}
if err := DB.Create(&data).Error; err != nil {
@ -44,10 +51,17 @@ func TestScannerValuer(t *testing.T) {
var result ScannerValuerStruct
if err := DB.Find(&result).Error; err != nil {
if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
}
if result.ExampleStructPtr.Val != "value2" {
t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val)
}
if result.ExampleStruct.Val != "value1" {
t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct)
}
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
}
@ -58,9 +72,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) {
}
data := ScannerValuerStruct{
Name: sql.NullString{String: "name", Valid: true},
Gender: &sql.NullString{String: "M", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: true},
Name: sql.NullString{String: "name", Valid: true},
Gender: &sql.NullString{String: "M", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: true},
ExampleStruct: ExampleStruct{"name", "value1"},
ExampleStructPtr: &ExampleStruct{"name", "value2"},
}
var result ScannerValuerStruct
@ -99,7 +115,9 @@ func TestInvalidValuer(t *testing.T) {
}
data := ScannerValuerStruct{
Password: EncryptedData("xpass1"),
Password: EncryptedData("xpass1"),
ExampleStruct: ExampleStruct{"name", "value1"},
ExampleStructPtr: &ExampleStruct{"name", "value2"},
}
if err := DB.Create(&data).Error; err == nil {
@ -124,21 +142,24 @@ func TestInvalidValuer(t *testing.T) {
type ScannerValuerStruct struct {
gorm.Model
Name sql.NullString
Gender *sql.NullString
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
Birthday sql.NullTime
Password EncryptedData
Bytes []byte
Num Num
Strings StringsSlice
Structs StructsSlice
Role Role
UserID *sql.NullInt64
User User
EmptyTime EmptyTime
Name sql.NullString
Gender *sql.NullString
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
Birthday sql.NullTime
Allergen NullString
Password EncryptedData
Bytes []byte
Num Num
Strings StringsSlice
Structs StructsSlice
Role Role
UserID *sql.NullInt64
User User
EmptyTime EmptyTime
ExampleStruct ExampleStruct
ExampleStructPtr *ExampleStruct
}
type EncryptedData []byte
@ -203,8 +224,32 @@ func (l *StringsSlice) Scan(input interface{}) error {
}
type ExampleStruct struct {
Name string
Value string
Name string
Val string
}
func (ExampleStruct) GormDataType() string {
return "bytes"
}
func (s ExampleStruct) Value() (driver.Value, error) {
if len(s.Name) == 0 {
return nil, nil
}
// for test, has no practical meaning
s.Name = ""
return json.Marshal(s)
}
func (s *ExampleStruct) Scan(src interface{}) error {
switch value := src.(type) {
case string:
return json.Unmarshal([]byte(value), s)
case []byte:
return json.Unmarshal(value, s)
default:
return errors.New("not supported")
}
}
type StructsSlice []ExampleStruct
@ -260,3 +305,89 @@ func (t *EmptyTime) Scan(v interface{}) error {
func (t EmptyTime) Value() (driver.Value, error) {
return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil
}
type NullString struct {
sql.NullString
}
type Point struct {
X, Y int
}
func (point Point) GormDataType() string {
return "geo"
}
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
return clause.Expr{
SQL: "ST_PointFromText(?)",
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
}
}
func TestGORMValuer(t *testing.T) {
type UserWithPoint struct {
Name string
Point Point
}
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Create(&UserWithPoint{
Name: "jinzhu",
Point: Point{X: 100, Y: 100},
}).Statement
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
}
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
}
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
}
stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{
"Name": "jinzhu",
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
}).Statement
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
}
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
}
stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{
"Name": "jinzhu",
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
}).Statement
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
}
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
}
stmt = dryRunDB.Session(&gorm.Session{
AllowGlobalUpdate: true,
}).Model(&UserWithPoint{}).Updates(UserWithPoint{
Name: "jinzhu",
Point: Point{X: 100, Y: 100},
}).Statement
if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) {
t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String())
}
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
}
}

View File

@ -40,6 +40,17 @@ func TestTable(t *testing.T) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
r = dryDB.Create(&UserWithTable{}).Statement
if DB.Dialector.Name() != "sqlite" {
if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
} else {
if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
}
r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement
if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())

View File

@ -1,6 +1,7 @@
package tests_test
import (
"context"
"errors"
"testing"
@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) {
}
}
func TestCancelTransaction(t *testing.T) {
ctx := context.Background()
ctx, cancelFunc := context.WithCancel(ctx)
cancelFunc()
user := *GetUser("cancel_transaction", Config{})
DB.Create(&user)
err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var result User
tx.First(&result, user.ID)
return nil
})
if err == nil {
t.Fatalf("Transaction should get error when using cancelled context")
}
}
func TestTransactionWithBlock(t *testing.T) {
assertPanic := func(f func()) {
defer func() {

View File

@ -222,6 +222,10 @@ func TestBlockGlobalUpdate(t *testing.T) {
if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
t.Errorf("should returns missing WHERE clause while updating error, got err %v", err)
}
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil {
t.Errorf("should returns no error while enable global update, but got err %v", err)
}
}
func TestSelectWithUpdate(t *testing.T) {
@ -334,6 +338,15 @@ func TestSelectWithUpdateWithMap(t *testing.T) {
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
}
func TestWithUpdateWithInvalidMap(t *testing.T) {
user := *GetUser("update_with_invalid_map", Config{})
DB.Create(&user)
if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) {
t.Errorf("should returns error for unsupported updating data")
}
}
func TestOmitWithUpdate(t *testing.T) {
user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Create(&user)
@ -545,3 +558,31 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) {
t.Errorf("element's ignored field should not be updated")
}
}
func TestUpdateFromSubQuery(t *testing.T) {
user := *GetUser("update_from_sub_query", Config{Company: true})
if err := DB.Create(&user).Error; err != nil {
t.Errorf("failed to create user, got error: %v", err)
}
if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil {
t.Errorf("failed to update with sub query, got error %v", err)
}
var result User
DB.First(&result, user.ID)
if result.Name != user.Company.Name {
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
}
DB.Model(&user.Company).Update("Name", "new company name")
if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil {
t.Errorf("failed to update with sub query, got error %v", err)
}
DB.First(&result, user.ID)
if result.Name != "new company name" {
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
}
}

View File

@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
}
} else {
name := reflect.ValueOf(got).Type().Elem().Name()
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got)
}
return
}

View File

@ -29,8 +29,8 @@ func FileWithLineNum() string {
return ""
}
func IsChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*'
func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
}
func CheckTruth(val interface{}) bool {

14
utils/utils_test.go Normal file
View File

@ -0,0 +1,14 @@
package utils
import (
"strings"
"testing"
)
func TestIsValidDBNameChar(t *testing.T) {
for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} {
if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 {
t.Fatalf("failed to parse db name %v", db)
}
}
}