diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b626ce94..4388c31d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest, macos-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -38,7 +38,7 @@ jobs: sqlite_windows: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [windows-latest] runs-on: ${{ matrix.platform }} @@ -64,7 +64,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -108,7 +108,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run in macOS and widnowsOS runs-on: ${{ matrix.platform }} @@ -150,7 +150,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/README.md b/README.md index b51297c4..9c0aded0 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,11 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statment Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map -* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key * Auto Migrations * Logger -* Extendable, write Plugins based on GORM callbacks +* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every feature comes with tests * Developer Friendly @@ -40,4 +40,3 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) - diff --git a/association.go b/association.go index e59b8938..db77cc4e 100644 --- a/association.go +++ b/association.go @@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - tx = association.DB.Model(out) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE", "LIMIT") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: queryConds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: queryConds}) - } - - association.Error = tx.Find(out, conds...).Error + association.Error = association.buildCondition().Find(out, conds...).Error } - return association.Error } @@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error { association.Error = association.Replace(values...) } default: - association.saveAssociation(false, values...) + association.saveAssociation( /*clear*/ false, values...) } } @@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation(true, values...) + association.saveAssociation( /*clear*/ true, values...) // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -112,7 +88,7 @@ func (association *Association) Replace(values ...interface{}) error { updateMap[ref.ForeignKey.DBName] = nil } - association.DB.UpdateColumns(updateMap) + association.Error = association.DB.UpdateColumns(updateMap).Error } case schema.HasOne, schema.HasMany: var ( @@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error { var ( primaryFields, relPrimaryFields []*schema.Field joinPrimaryKeys, joinRelPrimaryKeys []string - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() ) for _, ref := range rel.References { @@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error } if association.Error == nil { + // clean up deleted values's foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -328,33 +305,8 @@ func (association *Association) Clear() error { func (association *Association) Count() (count int64) { if association.Error == nil { - var ( - conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() - tx = association.DB.Model(modelValue) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE", "LIMIT") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: conds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: conds}) - } - - association.Error = tx.Count(&count).Error + association.Error = association.buildCondition().Count(&count).Error } - return } @@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if len(values) != reflectValue.Len() { + // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { @@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: + // clear old data if clear && len(values) == 0 { association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) @@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } } + +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + return tx +} diff --git a/callbacks/associations.go b/callbacks/associations.go index 3508335a..2710ffe9 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -48,14 +48,19 @@ func SaveBeforeAssociations(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } } + } else { + break } } @@ -112,22 +117,24 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) - } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() } - } - elems = reflect.Append(elems, rv) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } } } @@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) @@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) diff --git a/callbacks/create.go b/callbacks/create.go index 3a414dd7..c59b14b5 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { - if i, ok := value.(gorm.BeforeCreateInterface); ok { + if i, ok := value.(BeforeCreateInterface); ok { called = true db.AddError(i.BeforeCreate(tx)) } @@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Slice, reflect.Array: if config.LastInsertIDReversed { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID-- } } } else { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID++ } } @@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) { for rows.Next() { BEGIN: reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) + if reflect.Indirect(reflectValue).Kind() != reflect.Struct { + break + } + for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) @@ -189,14 +203,14 @@ func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { - if i, ok := value.(gorm.AfterCreateInterface); ok { + if i, ok := value.(AfterCreateInterface); ok { called = true db.AddError(i.AfterCreate(tx)) } @@ -211,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) case []map[string]interface{}: values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) @@ -295,6 +313,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 288f2d69..e95117a1 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -11,7 +11,7 @@ import ( func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.BeforeDeleteInterface); ok { + if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) return true } @@ -41,7 +41,7 @@ func Delete(db *gorm.DB) { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) @@ -51,15 +51,15 @@ func Delete(db *gorm.DB) { } } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } - db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.Build("DELETE", "FROM", "WHERE") } + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -75,7 +75,7 @@ func Delete(db *gorm.DB) { func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterDeleteInterface); ok { + if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) return true } diff --git a/callbacks/helper.go b/callbacks/helper.go index 7bd910f6..e0a66dd2 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -20,12 +20,18 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter for _, k := range keys { value := mapValue[k] - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + values.Values[0] = append(values.Values[0], value) } } @@ -42,8 +48,10 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st for idx, mapValue := range mapValues { for k, v := range mapValue { - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if _, ok := result[k]; !ok { @@ -61,11 +69,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st sort.Strings(columns) values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + for i, v := range result[column] { - if i == 0 { + if len(values.Values[i]) == 0 { values.Values[i] = make([]interface{}, len(columns)) } + values.Values[i][idx] = v } } diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go new file mode 100644 index 00000000..2302470f --- /dev/null +++ b/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/callbacks/preload.go b/callbacks/preload.go index cd09a6d6..25b8cb2b 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) - tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } reflectResults := rel.FieldSchema.MakeSlice().Elem() - column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues) + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { @@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) fieldValues := make([]interface{}, len(relForeignFields)) diff --git a/callbacks/query.go b/callbacks/query.go index be829fbc..0703b92e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -96,7 +96,7 @@ func BuildQuerySQL(db *gorm.DB) { // inline joins if len(db.Statement.Joins) != 0 { - if len(db.Statement.Selects) == 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} @@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} - for name, conds := range db.Statement.Joins { + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name for _, s := range relation.FieldSchema.DBNames { @@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } else { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) } } @@ -214,7 +214,7 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterFindInterface); ok { + if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) return true } diff --git a/callbacks/update.go b/callbacks/update.go index 12806af6..46f59157 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { - if i, ok := value.(gorm.BeforeUpdateInterface); ok { + if i, ok := value.(BeforeUpdateInterface); ok { called = true db.AddError(i.BeforeUpdate(tx)) } @@ -69,7 +69,7 @@ func Update(db *gorm.DB) { db.Statement.Build("UPDATE", "SET", "WHERE") } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { db.AddError(gorm.ErrMissingWhereClause) return } @@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { - if i, ok := value.(gorm.AfterUpdateInterface); ok { + if i, ok := value.(AfterUpdateInterface); ok { called = true db.AddError(i.AfterUpdate(tx)) } @@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { sort.Strings(keys) for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { if field.DBName != "" { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) assignValue(field, value[k]) } } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { @@ -189,12 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) } } if !stmt.UpdatingColumn && stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { now := stmt.DB.NowFunc() @@ -217,7 +223,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) @@ -247,6 +254,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/chainable_api.go b/chainable_api.go index 4df8780e..c8417a6d 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -69,7 +69,7 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { if len(args) > 0 { tx = tx.Select(args[0], args[1:]...) } - return tx + return } // Select specify fields that you want when querying, creating, updating @@ -91,8 +91,9 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsChar) + fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { @@ -112,6 +113,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + + delete(tx.Statement.Clauses, "SELECT") } else { tx.Statement.AddClause(clause.Select{ Distinct: db.Statement.Distinct, @@ -130,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) } else { tx.Statement.Omits = columns } @@ -169,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if tx.Statement.Joins == nil { - tx.Statement.Joins = map[string][]interface{}{} - } - tx.Statement.Joins[query] = args + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } @@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsChar) + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/errors.go b/errors.go index 115b8e25..32ff8ec1 100644 --- a/errors.go +++ b/errors.go @@ -19,6 +19,8 @@ var ( ErrPrimaryKeyRequired = errors.New("primary key required") // ErrModelValueRequired model value required ErrModelValueRequired = errors.New("model value required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered diff --git a/finisher_api.go b/finisher_api.go index 33a4f121..2cde3c31 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -132,19 +133,46 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return } -func (tx *DB) assignExprsToValue(exprs []clause.Expression) { - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) +func (tx *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + } } - case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + exprs := tx.Statement.BuildCondition(value) + tx.assignInterfacesToValue(exprs) + default: + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(reflectValue); !isZero { + if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + } + } + } + } } - default: + } else if len(values) > 0 { + exprs := tx.Statement.BuildCondition(values[0], values[1:]...) + tx.assignInterfacesToValue(exprs) + return } } } @@ -154,22 +182,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } tx.Error = nil } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return } @@ -180,20 +206,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return tx.Create(dest) @@ -265,11 +289,14 @@ func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() } if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - defer tx.Statement.AddClause(clause.Select{}) + defer delete(tx.Statement.Clauses, "SELECT") } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} @@ -335,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsChar) + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, @@ -420,7 +447,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // Commit commit a transaction func (db *DB) Commit() *DB { - if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) @@ -431,7 +458,9 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { - db.AddError(committer.Rollback()) + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } } else { db.AddError(ErrInvalidTransaction) } diff --git a/gorm.go b/gorm.go index 1ace0099..3c187f42 100644 --- a/gorm.go +++ b/gorm.go @@ -32,6 +32,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -61,6 +63,7 @@ type Session struct { PrepareStmt bool WithConditions bool SkipDefaultTransaction bool + AllowGlobalUpdate bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -154,6 +157,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.SkipDefaultTransaction = true } + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/interfaces.go b/interfaces.go index b2ce59b3..e933952b 100644 --- a/interfaces.go +++ b/interfaces.go @@ -53,38 +53,7 @@ type TxCommitter interface { Rollback() error } -type BeforeCreateInterface interface { - BeforeCreate(*DB) error -} - -type AfterCreateInterface interface { - AfterCreate(*DB) error -} - -type BeforeUpdateInterface interface { - BeforeUpdate(*DB) error -} - -type AfterUpdateInterface interface { - AfterUpdate(*DB) error -} - -type BeforeSaveInterface interface { - BeforeSave(*DB) error -} - -type AfterSaveInterface interface { - AfterSave(*DB) error -} - -type BeforeDeleteInterface interface { - BeforeDelete(*DB) error -} - -type AfterDeleteInterface interface { - AfterDelete(*DB) error -} - -type AfterFindInterface interface { - AfterFind(*DB) error +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr } diff --git a/migrator.go b/migrator.go index 37051f81..ed8a8e26 100644 --- a/migrator.go +++ b/migrator.go @@ -42,6 +42,7 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index d50159dd..d93b8a6d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "gorm.io/gorm" @@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { - // TODO smart migrate data type for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { @@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, field := range stmt.Schema.FieldsByDBName { - if !tx.Migrator().HasColumn(value, field.DBName) { + var foundColumn *sql.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { return err } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err } } @@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { - fmt.Println(err) return err } } @@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { + // found, smart migrate + fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + alterColumn := false + + // check size + if length, _ := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) + if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { + alterColumn = true + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + if alterColumn { + return m.DB.Migrator().AlterColumn(value, field.Name) + } + + return nil +} + func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() diff --git a/schema/field.go b/schema/field.go index 1ca4cb6d..524d19fb 100644 --- a/schema/field.go +++ b/schema/field.go @@ -55,6 +55,7 @@ type Field struct { Comment string Size int Precision int + Scale int FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -88,53 +89,44 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...) - } - - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...) - } - - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...) - } - - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...) - } - // if field is valuer, used its value or first fields as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { - var overrideFieldValue bool - if v, err := valuer.Value(); v != nil && err == nil { - overrideFieldValue = true + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { fieldValue = reflect.ValueOf(v) } - if field.IndirectFieldType.Kind() == reflect.Struct { - for i := 0; i < field.IndirectFieldType.NumField(); i++ { - if !overrideFieldValue { - newFieldType := field.IndirectFieldType.Field(i).Type + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + rv := reflect.Indirect(v) + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + for i := 0; i < rv.Type().NumField(); i++ { + newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - overrideFieldValue = true - } - // copy tag settings from valuer - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if rv.Type() != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } } } } } + + getRealFieldValue(fieldValue) } } @@ -169,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Precision, _ = strconv.Atoi(p) } + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true } @@ -311,15 +307,19 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error field.Creatable = false field.Updatable = false field.Readable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { schema.err = err } + for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema ef.OwnerSchema = field.EmbeddedSchema @@ -341,17 +341,20 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.PrimaryKey = true } else { ef.PrimaryKey = false + + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } } for k, v := range field.TagSettings { ef.TagSettings[k] = v } } - - field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) - field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } else { schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } @@ -475,16 +478,16 @@ func (field *Field) setupValuerAndSetter() { } } - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = setter(value, v) - } - } else if reflectV.Kind() == reflect.Ptr { + if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { err = setter(value, reflectV.Elem().Interface()) } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(value, v) + } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } @@ -731,40 +734,10 @@ func (field *Field) setupValuerAndSetter() { return nil } default: - if _, ok := fieldValue.Interface().(sql.Scanner); ok { - // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - - reflectV := reflect.ValueOf(v) - if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else if reflectV.Kind() == reflect.Ptr { - if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(value, reflectV.Elem().Interface()) - } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } - return - } - } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) - - if valuer, ok := v.(driver.Valuer); ok { - if valuer == nil || reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - v, _ = valuer.Value() - } - } - if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { @@ -778,10 +751,38 @@ func (field *Field) setupValuerAndSetter() { if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + err = fieldValue.Interface().(sql.Scanner).Scan(v) } return } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { return fallbackSetter(value, v, field.Set) diff --git a/schema/interfaces.go b/schema/interfaces.go index f5d07843..98abffbd 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -1,23 +1,25 @@ package schema -import "gorm.io/gorm/clause" +import ( + "gorm.io/gorm/clause" +) type GormDataTypeInterface interface { GormDataType() string } type CreateClausesInterface interface { - CreateClauses() []clause.Interface + CreateClauses(*Field) []clause.Interface } type QueryClausesInterface interface { - QueryClauses() []clause.Interface + QueryClauses(*Field) []clause.Interface } type UpdateClausesInterface interface { - UpdateClauses() []clause.Interface + UpdateClauses(*Field) []clause.Interface } type DeleteClausesInterface interface { - DeleteClauses() []clause.Interface + DeleteClauses(*Field) []clause.Interface } diff --git a/schema/model_test.go b/schema/model_test.go index a13372b5..1f2b0948 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -39,3 +39,24 @@ type AdvancedDataTypeUser struct { Active mybool Admin *mybool } + +type BaseModel struct { + ID uint + CreatedAt time.Time + CreatedBy *int + Created *VersionUser `gorm:"foreignKey:CreatedBy"` + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +type VersionModel struct { + BaseModel + Version int +} + +type VersionUser struct { + VersionModel + Name string + Age uint + Birthday *time.Time +} diff --git a/schema/relationship.go b/schema/relationship.go index 93080105..5132ff74 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -5,7 +5,6 @@ import ( "reflect" "regexp" "strings" - "sync" "github.com/jinzhu/inflection" "gorm.io/gorm/clause" @@ -67,16 +66,14 @@ func (schema *Schema) parseRelation(field *Field) { } ) + cacheStore := schema.cacheStore if field.OwnerSchema != nil { - if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { - schema.err = err - return - } - } else { - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { - schema.err = err - return - } + cacheStore = field.OwnerSchema.cacheStore + } + + if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { @@ -85,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) { schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.IndirectFieldType.Kind() { - case reflect.Struct, reflect.Slice: + case reflect.Struct: + schema.guessRelation(relation, field, guessBelongs) + case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) @@ -228,7 +227,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -251,7 +250,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -327,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel type guessLevel int const ( - guessHas guessLevel = iota - guessEmbeddedHas - guessBelongs + guessBelongs guessLevel = iota guessEmbeddedBelongs + guessHas + guessEmbeddedHas ) func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { @@ -339,34 +338,36 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue primarySchema, foreignSchema = schema, relation.FieldSchema ) - reguessOrErr := func(err string, args ...interface{}) { + reguessOrErr := func() { switch gl { - case guessHas: - schema.guessRelation(relation, field, guessEmbeddedHas) - case guessEmbeddedHas: - schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: default: - schema.err = fmt.Errorf(err, args...) + schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } switch gl { - case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) - return - } case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: if field.OwnerSchema != nil { primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() + return + } + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } else { + reguessOrErr() return } } @@ -376,7 +377,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if f := foreignSchema.LookUpField(foreignKey); f != nil { foreignFields = append(foreignFields, f) } else { - reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys) + reguessOrErr() return } } @@ -395,7 +396,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } else if len(relation.primaryKeys) > 0 { for idx, primaryKey := range relation.primaryKeys { @@ -403,11 +404,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if len(primaryFields) < idx+1 { primaryFields = append(primaryFields, f) } else if f != primaryFields[idx] { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } else { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } @@ -417,7 +418,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } else if len(primarySchema.PrimaryFields) == len(foreignFields) { primaryFields = append(primaryFields, primarySchema.PrimaryFields...) } else { - reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name) + reguessOrErr() return } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2c09f528..2e85c538 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) { }) } +func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy *int32 + Creator *User `gorm:"foreignKey:CreatedBy;references:ID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}}, + }) +} + func TestHasOneOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/schema/schema.go b/schema/schema.go index 9206c24e..458256d1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -212,11 +212,29 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { - // parse relations for unidentified fields - for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && field.Creatable { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } + } + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) } } } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index f202b487..cc0306e0 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -49,10 +49,15 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } } - if parsedField, ok := s.FieldsByName[f.Name]; !ok { + parsedField, ok := s.FieldsByDBName[f.DBName] + if !ok { + parsedField, ok = s.FieldsByName[f.Name] + } + + if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings") if f.DBName != "" { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { @@ -62,7 +67,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* for _, name := range []string{f.DBName, f.Name} { if name != "" { - if field := s.LookUpField(name); field == nil || parsedField != field { + if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 99781e47..4d13ebd2 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -4,6 +4,7 @@ import ( "sync" "testing" + "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) @@ -160,3 +161,65 @@ func TestCustomizeTableName(t *testing.T) { t.Errorf("Failed to customize table with TableName method") } } + +func TestNestedModel(t *testing.T) { + versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse nested user, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64}, + {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, + } + + for _, f := range fields { + checkSchemaField(t, versionUser, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} + +func TestEmbeddedStruct(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} diff --git a/schema/utils.go b/schema/utils.go index 1481d428..41bd9d60 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -9,6 +9,8 @@ import ( "gorm.io/gorm/utils" ) +var embeddedCacheKey = "embedded_cache_store" + func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) @@ -49,8 +51,11 @@ func toColumns(val string) (results []string) { return } -func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { - return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) +func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { + for _, name := range names { + tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) + } + return tag } // GetRelationsValues get relations's values from a reflect value diff --git a/soft_delete.go b/soft_delete.go index 6b88b1a5..484f565c 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) { return n.Time, nil } -func (DeletedAt) QueryClauses() []clause.Interface { +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { return []clause.Interface{ clause.Where{Exprs: []clause.Expression{ clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, + Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: nil, }, }}, } } -func (DeletedAt) DeleteClauses() []clause.Interface { - return []clause.Interface{SoftDeleteClause{}} +type SoftDeleteQueryClause struct { + Field *schema.Field } -type SoftDeleteClause struct { -} - -func (SoftDeleteClause) Name() string { +func (sd SoftDeleteQueryClause) Name() string { return "" } -func (SoftDeleteClause) Build(clause.Builder) { +func (sd SoftDeleteQueryClause) Build(clause.Builder) { } -func (SoftDeleteClause) MergeClause(*clause.Clause) { +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { } -func (SoftDeleteClause) ModifyStatement(stmt *Statement) { +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f}} +} + +type SoftDeleteDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) @@ -64,7 +88,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if stmt.Dest != stmt.Model && stmt.Model != nil { + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) @@ -74,11 +98,6 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !ok { - stmt.DB.AddError(ErrMissingWhereClause) - return - } - stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build("UPDATE", "SET", "WHERE") } diff --git a/statement.go b/statement.go index e9d826c4..fba1991d 100644 --- a/statement.go +++ b/statement.go @@ -29,7 +29,7 @@ type Statement struct { Distinct bool Selects []string // selected columns Omits []string // omit columns - Joins map[string][]interface{} + Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool @@ -44,6 +44,11 @@ type Statement struct { assigns []interface{} } +type join struct { + Name string + Conds []interface{} +} + // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) @@ -156,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) + case Valuer: + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) case clause.Expr: var varStr strings.Builder var sql = v.SQL @@ -309,10 +316,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } @@ -322,10 +329,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } @@ -374,7 +381,6 @@ func (stmt *Statement) Build(clauses ...string) { } } } - // TODO handle named vars } func (stmt *Statement) Parse(value interface{}) (err error) { @@ -392,6 +398,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) clone() *Statement { newStmt := &Statement{ + TableExpr: stmt.TableExpr, Table: stmt.Table, Model: stmt.Model, Dest: stmt.Dest, @@ -400,7 +407,6 @@ func (stmt *Statement) clone() *Statement { Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, - Joins: map[string][]interface{}{}, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, @@ -416,8 +422,9 @@ func (stmt *Statement) clone() *Statement { newStmt.Preloads[k] = p } - for k, j := range stmt.Joins { - newStmt.Joins[k] = j + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) } stmt.Settings.Range(func(k, v interface{}) bool { diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index d8befd8a..173e9231 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -21,6 +21,23 @@ func TestHasManyAssociation(t *testing.T) { DB.Model(&user2).Association("Pets").Find(&user2.Pets) CheckUser(t, user2, user) + var pets []Pet + DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets) + + if len(pets) != 1 { + t.Fatalf("should only find one pets, but got %v", len(pets)) + } + + CheckPet(t, pets[0], *user.Pets[0]) + + if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 { + t.Fatalf("should only find one pets, but got %v", count) + } + + if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 { + t.Fatalf("should only find no pet with invalid conditions, but got %v", count) + } + // Count AssertAssociationCount(t, user, "Pets", 2, "") @@ -40,13 +57,13 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") - var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} - if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { + if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } - for _, pet := range pets { + for _, pet := range pets2 { var pet = pet if pet.ID == 0 { t.Fatalf("Pet's ID should be created") diff --git a/tests/count_test.go b/tests/count_test.go index 05661ae8..216fa3a1 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -67,4 +67,9 @@ func TestCount(t *testing.T) { if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) } + + var count4 int64 + if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count) + } } diff --git a/tests/create_test.go b/tests/create_test.go index ae6e1232..ab0a78d4 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -39,6 +39,45 @@ func TestCreate(t *testing.T) { } } +func TestCreateFromMap(t *testing.T) { + if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result User + if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + datas := []map[string]interface{}{ + {"Name": "create_from_map_2", "Age": 19}, + {"name": "create_from_map_3", "Age": 20}, + } + + if err := DB.Model(&User{}).Create(datas).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var result3 User + if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } +} + func TestCreateWithAssociations(t *testing.T) { var user = *GetUser("create_with_associations", Config{ Account: true, diff --git a/tests/default_value_test.go b/tests/default_value_test.go index ea496d60..aa4a511a 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -10,7 +10,7 @@ func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model Email string `gorm:"not null;index:,unique"` - Name string `gorm:"not null;default:'foo'"` + Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` Age int `gorm:"default:18"` diff --git a/tests/delete_test.go b/tests/delete_test.go index 3d461f65..17299677 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -43,6 +43,18 @@ func TestDelete(t *testing.T) { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } + + if err := DB.Delete(users[0]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } } func TestDeleteWithTable(t *testing.T) { @@ -110,4 +122,8 @@ func TestBlockGlobalDelete(t *testing.T) { if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while deleting error") } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index fb0d6f23..c29078bd 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -160,9 +160,9 @@ func TestEmbeddedRelations(t *testing.T) { Advanced bool } - DB.Debug().Migrator().DropTable(&AdvancedUser{}) + DB.Migrator().DropTable(&AdvancedUser{}) - if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { + if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { t.Errorf("Failed to auto migrate advanced user, got error %v", err) } } diff --git a/tests/go.mod b/tests/go.mod index 6eb6eb07..1a6fe7a8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.1 - gorm.io/driver/postgres v0.2.6 - gorm.io/driver/sqlite v1.0.8 - gorm.io/driver/sqlserver v0.2.5 - gorm.io/gorm v0.2.19 + gorm.io/driver/mysql v1.0.0 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.0 + gorm.io/driver/sqlserver v1.0.0 + gorm.io/gorm v1.9.19 ) replace gorm.io/gorm => ../ diff --git a/tests/joins_test.go b/tests/joins_test.go index e54d3784..f78ddf67 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "sort" "testing" @@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) { if db5.Error != nil { t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement + + if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } } func TestJoinsWithSelect(t *testing.T) { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 1b002049..4cc8a7c3 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) { } } +func TestSmartMigrateColumn(t *testing.T) { + type UserMigrateColumn struct { + ID uint + Name string + Salary float64 + Birthday time.Time + } + + DB.Migrator().DropTable(&UserMigrateColumn{}) + + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 128 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + + type UserMigrateColumn3 struct { + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 256 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("salary's precision should be 2, but got %v", precision) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + +} + func TestMigrateWithComment(t *testing.T) { type UserWithComment struct { gorm.Model diff --git a/tests/preload_test.go b/tests/preload_test.go index 3caa17b4..7e5d2622 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) { } CheckUser(t, users2[0], users[0]) + + var users3 []User + if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { + return tx.Table("accounts AS a").Select("a.*") + }).Find(&users3, "id IN ?", userIDs).Error; err != nil { + t.Errorf("failed to query, got error %v", err) + } + sort.Slice(users3, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for i, u := range users3 { + CheckUser(t, u, users[i]) + } } func TestNestedPreloadWithConds(t *testing.T) { diff --git a/tests/query_test.go b/tests/query_test.go index 59f1130b..d71c813a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -103,6 +103,24 @@ func TestFind(t *testing.T) { }) } +func TestQueryWithAssociation(t *testing.T) { + user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create user: %v", err) + } + + user.CreatedAt = time.Time{} + user.UpdatedAt = time.Time{} + if err := DB.Where(&user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } + + if err := DB.Where(user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } +} + func TestFindInBatches(t *testing.T) { var users = []User{ *GetUser("find_in_batches", Config{}), @@ -328,6 +346,11 @@ func TestSelect(t *testing.T) { if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) } + + r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } } func TestOmit(t *testing.T) { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 63a7c63c..fb1f5791 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -1,16 +1,20 @@ package tests_test import ( + "context" "database/sql" "database/sql/driver" "encoding/json" "errors" + "fmt" "reflect" + "regexp" "strconv" "testing" "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -27,6 +31,7 @@ func TestScannerValuer(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, Password: EncryptedData("pass1"), Bytes: []byte("byte"), Num: 18, @@ -35,7 +40,9 @@ func TestScannerValuer(t *testing.T) { {"name1", "value1"}, {"name2", "value2"}, }, - Role: Role{Name: "admin"}, + Role: Role{Name: "admin"}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err != nil { @@ -44,10 +51,17 @@ func TestScannerValuer(t *testing.T) { var result ScannerValuerStruct - if err := DB.Find(&result).Error; err != nil { + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) } + if result.ExampleStructPtr.Val != "value2" { + t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value1" { + t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) + } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") } @@ -58,9 +72,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) { } data := ScannerValuerStruct{ - Name: sql.NullString{String: "name", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } var result ScannerValuerStruct @@ -99,7 +115,9 @@ func TestInvalidValuer(t *testing.T) { } data := ScannerValuerStruct{ - Password: EncryptedData("xpass1"), + Password: EncryptedData("xpass1"), + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err == nil { @@ -124,21 +142,24 @@ func TestInvalidValuer(t *testing.T) { type ScannerValuerStruct struct { gorm.Model - Name sql.NullString - Gender *sql.NullString - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - Birthday sql.NullTime - Password EncryptedData - Bytes []byte - Num Num - Strings StringsSlice - Structs StructsSlice - Role Role - UserID *sql.NullInt64 - User User - EmptyTime EmptyTime + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Allergen NullString + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime + ExampleStruct ExampleStruct + ExampleStructPtr *ExampleStruct } type EncryptedData []byte @@ -203,8 +224,32 @@ func (l *StringsSlice) Scan(input interface{}) error { } type ExampleStruct struct { - Name string - Value string + Name string + Val string +} + +func (ExampleStruct) GormDataType() string { + return "bytes" +} + +func (s ExampleStruct) Value() (driver.Value, error) { + if len(s.Name) == 0 { + return nil, nil + } + // for test, has no practical meaning + s.Name = "" + return json.Marshal(s) +} + +func (s *ExampleStruct) Scan(src interface{}) error { + switch value := src.(type) { + case string: + return json.Unmarshal([]byte(value), s) + case []byte: + return json.Unmarshal(value, s) + default: + return errors.New("not supported") + } } type StructsSlice []ExampleStruct @@ -260,3 +305,89 @@ func (t *EmptyTime) Scan(v interface{}) error { func (t EmptyTime) Value() (driver.Value, error) { return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } + +type NullString struct { + sql.NullString +} + +type Point struct { + X, Y int +} + +func (point Point) GormDataType() string { + return "geo" +} + +func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + return clause.Expr{ + SQL: "ST_PointFromText(?)", + Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, + } +} + +func TestGORMValuer(t *testing.T) { + type UserWithPoint struct { + Name string + Point Point + } + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Session(&gorm.Session{ + AllowGlobalUpdate: true, + }).Model(&UserWithPoint{}).Updates(UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { + t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } +} diff --git a/tests/table_test.go b/tests/table_test.go index faee6499..647b5e19 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -40,6 +40,17 @@ func TestTable(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c101388a..aea151d9 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "errors" "testing" @@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) { } } +func TestCancelTransaction(t *testing.T) { + ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + user := *GetUser("cancel_transaction", Config{}) + DB.Create(&user) + + err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var result User + tx.First(&result, user.ID) + return nil + }) + + if err == nil { + t.Fatalf("Transaction should get error when using cancelled context") + } +} + func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() { diff --git a/tests/update_test.go b/tests/update_test.go index 2ff150dd..e52dc652 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -222,6 +222,10 @@ func TestBlockGlobalUpdate(t *testing.T) { if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } func TestSelectWithUpdate(t *testing.T) { @@ -334,6 +338,15 @@ func TestSelectWithUpdateWithMap(t *testing.T) { AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") } +func TestWithUpdateWithInvalidMap(t *testing.T) { + user := *GetUser("update_with_invalid_map", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error for unsupported updating data") + } +} + func TestOmitWithUpdate(t *testing.T) { user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Create(&user) @@ -545,3 +558,31 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) { t.Errorf("element's ignored field should not be updated") } } + +func TestUpdateFromSubQuery(t *testing.T) { + user := *GetUser("update_from_sub_query", Config{Company: true}) + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error: %v", err) + } + + if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Company.Name { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } + + DB.Model(&user.Company).Update("Name", "new company name") + if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + DB.First(&result, user.ID) + if result.Name != "new company name" { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } +} diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 0067d5c6..817e4b0b 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } } else { name := reflect.ValueOf(got).Type().Elem().Name() - t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got) } return } diff --git a/utils/utils.go b/utils/utils.go index e93f3055..71336f4b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -29,8 +29,8 @@ func FileWithLineNum() string { return "" } -func IsChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' } func CheckTruth(val interface{}) bool { diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..5737c511 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,14 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestIsValidDBNameChar(t *testing.T) { + for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { + if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { + t.Fatalf("failed to parse db name %v", db) + } + } +}