diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 380231b9..af471d20 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,7 +30,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -73,7 +73,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -116,7 +116,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -159,7 +159,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -202,7 +202,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -235,7 +235,7 @@ jobs: - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} diff --git a/callbacks/preload.go b/callbacks/preload.go index 15669c84..25ecfe76 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "sort" "strings" "gorm.io/gorm" @@ -82,27 +83,80 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { return names } -func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { - if relationships == nil { - return nil +// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point. +// If the current relationship is embedded or joined, current query will be ignored. +// +//nolint:cyclop +func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { + preloadMap := parsePreloadMap(db.Statement.Schema, preloads) + + // avoid random traversal of the map + preloadNames := make([]string, 0, len(preloadMap)) + for key := range preloadMap { + preloadNames = append(preloadNames, key) } - preloadMap := parsePreloadMap(s, preloads) - for name := range preloadMap { - if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { - if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + sort.Strings(preloadNames) + + isJoined := func(name string) (joined bool, nestedJoins []string) { + for _, join := range joins { + if _, ok := relationships.Relations[join]; ok && name == join { + joined = true + continue + } + joinNames := strings.SplitN(join, ".", 2) + if len(joinNames) == 2 { + if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { + joined = true + nestedJoins = append(nestedJoins, joinNames[1]) + } + } + } + return joined, nestedJoins + } + + for _, name := range preloadNames { + if relations := relationships.EmbeddedRelations[name]; relations != nil { + if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { - if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { - return err + if joined, nestedJoins := isJoined(name); joined { + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } else { + tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) + tx.Statement.ReflectValue = db.Statement.ReflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { + return err + } } } else { - return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name) } } return nil } +func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB { + tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if err := tx.Statement.Parse(dest); err != nil { + tx.AddError(err) + return tx + } + tx.Statement.ReflectValue = reflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + return tx +} + func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue diff --git a/callbacks/query.go b/callbacks/query.go index e89dd199..2a82eaba 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -3,7 +3,6 @@ package callbacks import ( "fmt" "reflect" - "sort" "strings" "gorm.io/gorm" @@ -254,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) { } db.Statement.AddClause(fromClause) - db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) } @@ -272,38 +270,23 @@ func Preload(db *gorm.DB) { return } - preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) - preloadNames := make([]string, 0, len(preloadMap)) - for key := range preloadMap { - preloadNames = append(preloadNames, key) + joins := make([]string, 0, len(db.Statement.Joins)) + for _, join := range db.Statement.Joins { + joins = append(joins, join.Name) } - sort.Strings(preloadNames) - preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) - db.Statement.Settings.Range(func(k, v interface{}) bool { - preloadDB.Statement.Settings.Store(k, v) - return true - }) - - if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest) + if tx.Error != nil { return } - preloadDB.Statement.ReflectValue = db.Statement.ReflectValue - preloadDB.Statement.Unscoped = db.Statement.Unscoped - for _, name := range preloadNames { - if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { - db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) - } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { - db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) - } else { - db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) - } - } + db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) } } func AfterQuery(db *gorm.DB) { + // clear the joins after query because preload need it + db.Statement.Joins = nil if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { diff --git a/chainable_api.go b/chainable_api.go index 3dc7256e..1ec9b865 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -367,33 +367,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } func (db *DB) executeScopes() (tx *DB) { - tx = db.getInstance() scopes := db.Statement.scopes - if len(scopes) == 0 { - return tx - } - tx.Statement.scopes = nil - - conditions := make([]clause.Interface, 0, 4) - if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { - conditions = append(conditions, cs.Expression.(clause.Interface)) - cs.Expression = nil - tx.Statement.Clauses["WHERE"] = cs - } - + db.Statement.scopes = nil for _, scope := range scopes { - tx = scope(tx) - if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { - conditions = append(conditions, cs.Expression.(clause.Interface)) - cs.Expression = nil - tx.Statement.Clauses["WHERE"] = cs - } + db = scope(db) } - - for _, condition := range conditions { - tx.Statement.AddClause(condition) - } - return tx + return db } // Preload preload associations with given conditions diff --git a/clause/locking.go b/clause/locking.go index 290aac92..2bc48ceb 100644 --- a/clause/locking.go +++ b/clause/locking.go @@ -1,5 +1,12 @@ package clause +const ( + LockingStrengthUpdate = "UPDATE" + LockingStrengthShare = "SHARE" + LockingOptionsSkipLocked = "SKIP LOCKED" + LockingOptionsNoWait = "NOWAIT" +) + type Locking struct { Strength string Table Table diff --git a/clause/locking_test.go b/clause/locking_test.go index 0e607312..e45c8e7d 100644 --- a/clause/locking_test.go +++ b/clause/locking_test.go @@ -14,17 +14,21 @@ func TestLocking(t *testing.T) { Vars []interface{} }{ { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}}, "SELECT * FROM `users` FOR UPDATE", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}}, "SELECT * FROM `users` FOR SHARE OF `users`", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}}, "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}}, + "SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil, + }, } for idx, result := range results { diff --git a/clause/where.go b/clause/where.go index a29401cf..46d0b319 100644 --- a/clause/where.go +++ b/clause/where.go @@ -21,6 +21,12 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } + // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { @@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression { if len(exprs) == 0 { return nil } + if len(exprs) == 1 { + if andCondition, ok := exprs[0].(AndConditions); ok { + exprs = andCondition.Exprs + } + } return NotConditions{Exprs: exprs} } diff --git a/clause/where_test.go b/clause/where_test.go index 35e3dbee..aa9d06eb 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -63,7 +63,7 @@ func TestWhere(t *testing.T) { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, }}, - "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", + "SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?", []interface{}{18, "jinzhu"}, }, { @@ -94,7 +94,7 @@ func TestWhere(t *testing.T) { clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), }, }}, - "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + "SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?", []interface{}{"1", 100}, }, { diff --git a/finisher_api.go b/finisher_api.go index f80aa6c0..f97571ed 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -376,8 +376,12 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } else if len(db.Statement.assigns) > 0 { exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { + for i := 0; i < len(exprs); i++ { + expr := exprs[i] + + if eq, ok := expr.(clause.AndConditions); ok { + exprs = append(exprs, eq.Exprs...) + } else if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: assigns[column] = eq.Value diff --git a/logger/sql.go b/logger/sql.go index 13e5d957..8ce8d8b1 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -79,17 +79,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper } else { vars[idx] = nullStr } } case []byte: if s := string(v); isPrintable(s) { - vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } @@ -100,7 +100,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64: vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: - vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { @@ -117,7 +117,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper } } } diff --git a/logger/sql_test.go b/logger/sql_test.go index a82fa546..036ef3a4 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper } func TestExplainSQL(t *testing.T) { @@ -40,7 +40,7 @@ func TestExplainSQL(t *testing.T) { var ( tt = now.MustParse("2020-02-23 11:10:10") myrole = role("admin") - pwd = password([]byte("pass")) + pwd = password("pass") jsVal = []byte(`{"Name":"test","Val":"test"}`) js = JSON(jsVal) esVal = []byte(`{"Name":"test","Val":"test"}`) @@ -57,13 +57,13 @@ func TestExplainSQL(t *testing.T) { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", @@ -87,25 +87,25 @@ func TestExplainSQL(t *testing.T) { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, } diff --git a/schema/relationship.go b/schema/relationship.go index 0a939da8..2e94fc2c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return nil } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - schema.buildPolymorphicRelation(relation, field, polymorphic) + if hasPolymorphicRelation(field.TagSettings) { + schema.buildPolymorphicRelation(relation, field) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { @@ -89,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, + field.Name) } } @@ -124,6 +125,20 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } +// hasPolymorphicRelation check if has polymorphic relation +// 1. `POLYMORPHIC` tag +// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag +func hasPolymorphicRelation(tagSettings map[string]string) bool { + if _, ok := tagSettings["POLYMORPHIC"]; ok { + return true + } + + _, hasType := tagSettings["POLYMORPHICTYPE"] + _, hasId := tagSettings["POLYMORPHICID"] + + return hasType && hasId +} + func (schema *Schema) setRelation(relation *Relationship) { // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { @@ -169,23 +184,41 @@ func (schema *Schema) setRelation(relation *Relationship) { // OwnerID int // OwnerType string // } -func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) { + polymorphic := field.TagSettings["POLYMORPHIC"] + relation.Polymorphic = &Polymorphic{ - Value: schema.Table, - PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], - PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + Value: schema.Table, } + var ( + typeName = polymorphic + "Type" + typeId = polymorphic + "ID" + ) + + if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok { + typeName = strings.TrimSpace(value) + } + + if value, ok := field.TagSettings["POLYMORPHICID"]; ok { + typeId = strings.TrimSpace(value) + } + + relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName] + relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId] + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { relation.Polymorphic.Value = strings.TrimSpace(value) } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -197,12 +230,14 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, + schema, field.Name) } } if primaryKeyField == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", + relation.FieldSchema, schema, field.Name) return } @@ -317,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Tag: `gorm:"-"`, }) - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, + schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many @@ -436,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", + schema, field.Name) } } @@ -492,7 +529,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames := []string{lookUpName} if len(primaryFields) == 1 { - lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", + strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, + strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } for _, name := range lookUpNames { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 1eb66bb4..23d79bbb 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -577,6 +577,193 @@ func TestEmbeddedHas(t *testing.T) { }) } +func TestPolymorphic(t *testing.T) { + t.Run("has one", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with only polymorphic type", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) +} + func TestEmbeddedBelongsTo(t *testing.T) { type Country struct { ID int `gorm:"primaryKey"` diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 605aa03a..bc326686 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) } - for _, f := range relation.JoinTable.Fields { - checkSchemaField(t, r.JoinTable, &f, nil) + for i := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 5bc0fb83..45e152e9 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -46,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } - for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) { + for i := range fields { + checkSchemaField(t, user, &fields[i], func(f *schema.Field) { f.Creatable = true f.Updatable = true f.Readable = true @@ -136,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, } - for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) { + for i := range fields { + checkSchemaField(t, user, &fields[i], func(f *schema.Field) { f.Creatable = true f.Updatable = true f.Readable = true diff --git a/statement.go b/statement.go index 59c0b772..ae79aa32 100644 --- a/statement.go +++ b/statement.go @@ -326,7 +326,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case *DB: v.executeScopes() - if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { @@ -334,13 +334,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } conds = append(conds, clause.And(where.Exprs...)) - } else { + } else if cs.Expression != nil { conds = append(conds, cs.Expression) } - if v.Statement == stmt { - cs.Expression = nil - stmt.Statement.Clauses["WHERE"] = cs - } } case map[interface{}]interface{}: for i, j := range v { @@ -451,8 +447,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if len(values) > 0 { conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + return []clause.Expression{clause.And(conds...)} } - return conds + return nil } } @@ -461,7 +458,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } - return conds + if len(conds) > 0 { + return []clause.Expression{clause.And(conds...)} + } + return nil } // Build build sql with clauses names @@ -665,7 +665,21 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) +var matchName = func() func(tableColumn string) (table, column string) { + nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`) + return func(tableColumn string) (table, column string) { + if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 { + table = matches[1] + star := matches[2] + columnName := matches[3] + if star != "" { + return table, star + } + return table, columnName + } + return "", "" + } +}() // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { @@ -686,13 +700,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = result - } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { - if matches[2] == "*" { + } else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") { + if col == "*" { for _, dbName := range stmt.Schema.DBNames { results[dbName] = result } } else { - results[matches[2]] = result + results[col] = result } } else { results[column] = result diff --git a/statement_test.go b/statement_test.go index 648bc875..0995d547 100644 --- a/statement_test.go +++ b/statement_test.go @@ -56,9 +56,15 @@ func TestNameMatcher(t *testing.T) { "`name_1`": {"", "name_1"}, "`Name_1`": {"", "Name_1"}, "`Table`.`nAme`": {"Table", "nAme"}, + "my_table.*": {"my_table", "*"}, + "`my_table`.*": {"my_table", "*"}, + "User__Company.*": {"User__Company", "*"}, + "`User__Company`.*": {"User__Company", "*"}, + `"User__Company".*`: {"User__Company", "*"}, + `"table"."*"`: {"", ""}, } { - if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { - t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + if table, column := matchName(k); table != v[0] || column != v[1] { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v) } } } diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index c31c4b40..b8e8ff5e 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -422,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasmany-1", Config{Toys: 2}), - *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}), *GetUser("slice-hasmany-3", Config{Toys: 4}), } @@ -430,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { // Count AssertAssociationCount(t, users, "Toys", 6, "") + AssertAssociationCount(t, users, "Tools", 2, "") // Find var toys []Toy @@ -437,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { t.Errorf("toys count should be %v, but got %v", 6, len(toys)) } + // Find Tools (polymorphic with custom type and id) + var tools []Tools + DB.Model(&users).Association("Tools").Find(&tools) + + if len(tools) != 2 { + t.Errorf("tools count should be %v, but got %v", 2, len(tools)) + } + // Append DB.Model(&users).Association("Toys").Append( &Toy{Name: "toy-slice-append-1"}, diff --git a/tests/go.mod b/tests/go.mod index 1c12a43c..460d96a5 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,7 +3,7 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.4.0 + github.com/google/uuid v1.5.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.8.4 @@ -20,15 +20,15 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect - github.com/mattn/go-sqlite3 v1.14.18 // indirect + github.com/mattn/go-sqlite3 v1.14.19 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect - golang.org/x/crypto v0.15.0 // indirect + golang.org/x/crypto v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/helper_test.go b/tests/helper_test.go index 1a4874ee..feb67f9e 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -23,6 +23,7 @@ type Config struct { Languages int Friends int NamedPet bool + Tools int } func GetUser(name string, config Config) *User { @@ -47,6 +48,10 @@ func GetUser(name string, config Config) *User { user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) } + for i := 0; i < config.Tools; i++ { + user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)}) + } + if config.Company { user.Company = Company{Name: "company-" + name} } @@ -118,11 +123,13 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { - AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", + "CompanyID", "ManagerID", "Active") } } - AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", + "ManagerID", "Active") t.Run("Account", func(t *testing.T) { AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") @@ -133,7 +140,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { } else { var account Account db(unscoped).First(&account, "user_id = ?", user.ID) - AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", + "Number") } } }) @@ -193,8 +201,10 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { } else { var manager User db(unscoped).First(&manager, "id = ?", *user.ManagerID) - AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } } else if user.ManagerID != nil { t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) @@ -215,7 +225,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { }) for idx, team := range user.Team { - AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) @@ -250,7 +261,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { }) for idx, friend := range user.Friends { - AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index cfd3e0ac..28fa315b 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -18,7 +18,7 @@ import ( ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") @@ -34,7 +34,7 @@ func TestMigrate(t *testing.T) { if tables, err := DB.Migrator().GetTables(); err != nil { t.Fatalf("Failed to get database all tables, but got error %v", err) } else { - for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} { + for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} { hasTable := false for _, t2 := range tables { if t2 == t1 { @@ -93,7 +93,8 @@ func TestAutoMigrateInt8PG(t *testing.T) { Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { - t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", + sql) } }, } @@ -432,40 +433,50 @@ func TestTiDBMigrateColumns(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !ok || length != 100 { - t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { - t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.Comment(); !ok || v != "my age" { - t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code": if v, ok := columnType.Unique(); !ok || !v { - t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.DefaultValue(); !ok || v != "hello" { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !ok || v != "my code2" { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code2": // Code2 string `gorm:"comment:my code2;default:hello"` if v, ok := columnType.DefaultValue(); !ok || v != "hello" { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !ok || v != "my code2" { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } } } @@ -497,7 +508,8 @@ func TestTiDBMigrateColumns(t *testing.T) { t.Fatalf("Failed to add column, got %v", err) } - if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } @@ -561,36 +573,45 @@ func TestMigrateColumns(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { - t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { - t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { - t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code": if v, ok := columnType.Unique(); !ok || !v { - t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code2": if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { - t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code3": // TODO @@ -627,7 +648,8 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("Failed to add column, got %v", err) } - if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } @@ -1555,7 +1577,8 @@ func TestMigrateIgnoreRelations(t *testing.T) { func TestMigrateView(t *testing.T) { DB.Save(GetUser("joins-args-db", Config{Pets: 2})) - if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { + if err := DB.Migrator().CreateView("invalid_users_pets", + gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { t.Fatalf("no view should be created, got %v", err) } @@ -1624,17 +1647,20 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "string_bool": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } case "smallint_bool": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } } } @@ -1659,7 +1685,8 @@ func TestTableType(t *testing.T) { DB.Migrator().DropTable(&City{}) - if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { + if err := DB.Set("gorm:table_options", + fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { t.Fatalf("failed to migrate cities tables, got error: %v", err) } diff --git a/tests/preload_test.go b/tests/preload_test.go index 3ff86492..26b08d7d 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -307,6 +307,63 @@ func TestNestedPreloadWithUnscoped(t *testing.T) { CheckUserUnscoped(t, *user6, user) } +func TestNestedPreloadWithNestedJoin(t *testing.T) { + type ( + Preload struct { + ID uint + Value string + NestedID uint + } + Join struct { + ID uint + Value string + NestedID uint + } + Nested struct { + ID uint + Preloads []*Preload + Join Join + ValueID uint + } + Value struct { + ID uint + Name string + Nested Nested + } + ) + + DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) + DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) + + value := Value{ + Name: "value", + Nested: Nested{ + Preloads: []*Preload{ + {Value: "p1"}, {Value: "p2"}, + }, + Join: Join{Value: "j1"}, + }, + } + if err := DB.Create(&value).Error; err != nil { + t.Errorf("failed to create value, got err: %v", err) + } + + var find1 Value + err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find1, value) + + var find2 Value + // Joins will automatically add Nested queries. + err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find2, value) +} + func TestEmbedPreload(t *testing.T) { type Country struct { ID int `gorm:"primaryKey"` diff --git a/tests/query_test.go b/tests/query_test.go index 5728378d..cadf7164 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1118,12 +1118,12 @@ func TestSearchWithStruct(t *testing.T) { } result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) - if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) - if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 52c6b37b..84aeb990 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -84,7 +84,9 @@ func TestComplexScopes(t *testing.T) { queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, - func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) }, + func(d *gorm.DB) *gorm.DB { + return d.Where(DB.Or("b = 2").Or("c = 3")) + }, ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, @@ -93,7 +95,9 @@ func TestComplexScopes(t *testing.T) { queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Where("z = 0").Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, - func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) }, + func(d *gorm.DB) *gorm.DB { + return d.Or(DB.Where("b = 2").Or("c = 3")) + }, ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, @@ -104,7 +108,7 @@ func TestComplexScopes(t *testing.T) { func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, func(d *gorm.DB) *gorm.DB { return d. - Or(d.Scopes( + Or(DB.Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, )). diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 022e0495..0c204db4 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -388,7 +388,7 @@ func TestToSQL(t *testing.T) { sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{}) }) - assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) + assertEqualSQL(t, `SELECT * FROM "users" WHERE ("users"."name" = 'foo' AND "users"."age" = 20) AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) // last and unscoped sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { diff --git a/tests/tests_test.go b/tests/tests_test.go index f9c6cab5..a127734e 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/tests/update_test.go b/tests/update_test.go index b719cc45..9eb9dbfc 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -882,3 +882,52 @@ func TestSaveWithHooks(t *testing.T) { t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content) } } + +// only postgres, sqlserver, sqlite support update from +func TestUpdateFrom(t *testing.T) { + if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" { + return + } + + users := []*User{ + GetUser("update-from-1", Config{Account: true}), + GetUser("update-from-2", Config{Account: true}), + GetUser("update-from-3", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if users[0].ID == 0 { + t.Fatalf("user's primary value should not zero, %v", users[0].ID) + } else if users[0].UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", users[0].UpdatedAt) + } + + if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number = ? AND accounts.deleted_at IS NULL", users[0].Account.Number).Update("name", "franco").RowsAffected; rowsAffected != 1 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } + + var result User + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } else if result.UpdatedAt.UnixNano() == users[0].UpdatedAt.UnixNano() { + t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, users[0].UpdatedAt) + } else if result.Name != "franco" { + t.Errorf("user's name should be updated") + } + + if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number IN ? AND accounts.deleted_at IS NULL", []string{users[0].Account.Number, users[1].Account.Number}).Update("name", gorm.Expr("accounts.number")).RowsAffected; rowsAffected != 2 { + t.Errorf("should update two records, but got %v", rowsAffected) + } + + var results []User + if err := DB.Preload("Account").Find(&results, []uint{users[0].ID, users[1].ID}).Error; err != nil { + t.Errorf("Not error should happen when finding users, but got %v", err) + } + + for _, user := range results { + if user.Name != user.Account.Number { + t.Errorf("user's name should be equal to the account's number %v, but got %v", user.Account.Number, user.Name) + } + } +} diff --git a/utils/tests/models.go b/utils/tests/models.go index a4bad2fc..f9f4f50e 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -20,7 +20,8 @@ type User struct { Account Account Pets []*Pet NamedPet *Pet - Toys []Toy `gorm:"polymorphic:Owner"` + Toys []Toy `gorm:"polymorphic:Owner"` + Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"` CompanyID *int Company Company ManagerID *uint @@ -51,6 +52,13 @@ type Toy struct { OwnerType string } +type Tools struct { + gorm.Model + Name string + CustomID string + Type string +} + type Company struct { ID int Name string diff --git a/utils/utils.go b/utils/utils.go index c8fec5b0..a4d8ac25 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -35,7 +35,8 @@ func FileWithLineNum() string { // the second caller usually from gorm internal, so set i start from 2 for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) && + !strings.HasSuffix(file, ".gen.go") { return file + ":" + strconv.FormatInt(int64(line), 10) } }