Merge branch 'master' into distinguish_unique
# Conflicts: # tests/go.mod
This commit is contained in:
commit
184ec8e136
12
.github/workflows/tests.yml
vendored
12
.github/workflows/tests.yml
vendored
@ -30,7 +30,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -73,7 +73,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -116,7 +116,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -159,7 +159,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -202,7 +202,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -235,7 +235,7 @@ jobs:
|
||||
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
@ -3,6 +3,7 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@ -82,27 +83,80 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
return names
|
||||
}
|
||||
|
||||
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
|
||||
if relationships == nil {
|
||||
return nil
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||
// If the current relationship is embedded or joined, current query will be ignored.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||
|
||||
// avoid random traversal of the map
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
preloadMap := parsePreloadMap(s, preloads)
|
||||
for name := range preloadMap {
|
||||
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
|
||||
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||
for _, join := range joins {
|
||||
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||
joined = true
|
||||
continue
|
||||
}
|
||||
joinNames := strings.SplitN(join, ".", 2)
|
||||
if len(joinNames) == 2 {
|
||||
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
|
||||
joined = true
|
||||
nestedJoins = append(nestedJoins, joinNames[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
return joined, nestedJoins
|
||||
}
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
|
||||
return err
|
||||
if joined, nestedJoins := isJoined(name); joined {
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
|
||||
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
|
||||
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := tx.Statement.Parse(dest); err != nil {
|
||||
tx.AddError(err)
|
||||
return tx
|
||||
}
|
||||
tx.Statement.ReflectValue = reflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
return tx
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
|
@ -3,7 +3,6 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@ -254,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.Statement.AddClause(fromClause)
|
||||
db.Statement.Joins = nil
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
@ -272,38 +270,23 @@ func Preload(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
joins := make([]string, 0, len(db.Statement.Joins))
|
||||
for _, join := range db.Statement.Joins {
|
||||
joins = append(joins, join.Name)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
preloadDB.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
|
||||
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
|
||||
if tx.Error != nil {
|
||||
return
|
||||
}
|
||||
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
|
||||
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
|
||||
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||
}
|
||||
}
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// clear the joins after query because preload need it
|
||||
db.Statement.Joins = nil
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
|
@ -367,33 +367,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
}
|
||||
|
||||
func (db *DB) executeScopes() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
scopes := db.Statement.scopes
|
||||
if len(scopes) == 0 {
|
||||
return tx
|
||||
}
|
||||
tx.Statement.scopes = nil
|
||||
|
||||
conditions := make([]clause.Interface, 0, 4)
|
||||
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
|
||||
conditions = append(conditions, cs.Expression.(clause.Interface))
|
||||
cs.Expression = nil
|
||||
tx.Statement.Clauses["WHERE"] = cs
|
||||
}
|
||||
|
||||
db.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
tx = scope(tx)
|
||||
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
|
||||
conditions = append(conditions, cs.Expression.(clause.Interface))
|
||||
cs.Expression = nil
|
||||
tx.Statement.Clauses["WHERE"] = cs
|
||||
}
|
||||
db = scope(db)
|
||||
}
|
||||
|
||||
for _, condition := range conditions {
|
||||
tx.Statement.AddClause(condition)
|
||||
}
|
||||
return tx
|
||||
return db
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
|
@ -1,5 +1,12 @@
|
||||
package clause
|
||||
|
||||
const (
|
||||
LockingStrengthUpdate = "UPDATE"
|
||||
LockingStrengthShare = "SHARE"
|
||||
LockingOptionsSkipLocked = "SKIP LOCKED"
|
||||
LockingOptionsNoWait = "NOWAIT"
|
||||
)
|
||||
|
||||
type Locking struct {
|
||||
Strength string
|
||||
Table Table
|
||||
|
@ -14,17 +14,21 @@ func TestLocking(t *testing.T) {
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
|
||||
"SELECT * FROM `users` FOR UPDATE", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
|
||||
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
|
||||
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
|
||||
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
@ -21,6 +21,12 @@ func (where Where) Name() string {
|
||||
|
||||
// Build build where clause
|
||||
func (where Where) Build(builder Builder) {
|
||||
if len(where.Exprs) == 1 {
|
||||
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
|
||||
where.Exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
|
||||
// Switch position if the first query expression is a single Or condition
|
||||
for idx, expr := range where.Exprs {
|
||||
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
||||
@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(exprs) == 1 {
|
||||
if andCondition, ok := exprs[0].(AndConditions); ok {
|
||||
exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
return NotConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
|
@ -63,7 +63,7 @@ func TestWhere(t *testing.T) {
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
|
||||
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
|
||||
[]interface{}{18, "jinzhu"},
|
||||
},
|
||||
{
|
||||
@ -94,7 +94,7 @@ func TestWhere(t *testing.T) {
|
||||
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
|
||||
},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
|
||||
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
|
||||
[]interface{}{"1", 100},
|
||||
},
|
||||
{
|
||||
|
@ -376,8 +376,12 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
} else if len(db.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
for i := 0; i < len(exprs); i++ {
|
||||
expr := exprs[i]
|
||||
|
||||
if eq, ok := expr.(clause.AndConditions); ok {
|
||||
exprs = append(exprs, eq.Exprs...)
|
||||
} else if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
assigns[column] = eq.Value
|
||||
|
@ -79,17 +79,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case reflect.Bool:
|
||||
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||
case reflect.String:
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
default:
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
}
|
||||
case []byte:
|
||||
if s := string(v); isPrintable(s) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + "<binary>" + escaper
|
||||
}
|
||||
@ -100,7 +100,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case float64:
|
||||
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case string:
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
@ -117,7 +117,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
return
|
||||
}
|
||||
}
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) {
|
||||
}
|
||||
|
||||
func format(v []byte, escaper string) string {
|
||||
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
|
||||
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
|
||||
}
|
||||
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
@ -40,7 +40,7 @@ func TestExplainSQL(t *testing.T) {
|
||||
var (
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password([]byte("pass"))
|
||||
pwd = password("pass")
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
@ -57,13 +57,13 @@ func TestExplainSQL(t *testing.T) {
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
||||
@ -87,25 +87,25 @@ func TestExplainSQL(t *testing.T) {
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
return nil
|
||||
}
|
||||
|
||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
schema.buildPolymorphicRelation(relation, field, polymorphic)
|
||||
if hasPolymorphicRelation(field.TagSettings) {
|
||||
schema.buildPolymorphicRelation(relation, field)
|
||||
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||
schema.buildMany2ManyRelation(relation, field, many2many)
|
||||
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
|
||||
@ -89,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
case reflect.Slice:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
default:
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
|
||||
field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,6 +125,20 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
return relation
|
||||
}
|
||||
|
||||
// hasPolymorphicRelation check if has polymorphic relation
|
||||
// 1. `POLYMORPHIC` tag
|
||||
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
|
||||
func hasPolymorphicRelation(tagSettings map[string]string) bool {
|
||||
if _, ok := tagSettings["POLYMORPHIC"]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
_, hasType := tagSettings["POLYMORPHICTYPE"]
|
||||
_, hasId := tagSettings["POLYMORPHICID"]
|
||||
|
||||
return hasType && hasId
|
||||
}
|
||||
|
||||
func (schema *Schema) setRelation(relation *Relationship) {
|
||||
// set non-embedded relation
|
||||
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
|
||||
@ -169,23 +184,41 @@ func (schema *Schema) setRelation(relation *Relationship) {
|
||||
// OwnerID int
|
||||
// OwnerType string
|
||||
// }
|
||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
|
||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
|
||||
polymorphic := field.TagSettings["POLYMORPHIC"]
|
||||
|
||||
relation.Polymorphic = &Polymorphic{
|
||||
Value: schema.Table,
|
||||
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
|
||||
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
|
||||
Value: schema.Table,
|
||||
}
|
||||
|
||||
var (
|
||||
typeName = polymorphic + "Type"
|
||||
typeId = polymorphic + "ID"
|
||||
)
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
|
||||
typeName = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
|
||||
typeId = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
|
||||
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
|
||||
relation.Polymorphic.Value = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicType == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicID == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
@ -197,12 +230,14 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
||||
primaryKeyField := schema.PrioritizedPrimaryField
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
|
||||
schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if primaryKeyField == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
|
||||
relation.FieldSchema, schema, field.Name)
|
||||
return
|
||||
}
|
||||
|
||||
@ -317,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
Tag: `gorm:"-"`,
|
||||
})
|
||||
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
|
||||
schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
@ -436,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
// case guessEmbeddedHas:
|
||||
default:
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
|
||||
schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -492,7 +529,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
|
||||
lookUpNames := []string{lookUpName}
|
||||
if len(primaryFields) == 1 {
|
||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
}
|
||||
|
||||
for _, name := range lookUpNames {
|
||||
|
@ -577,6 +577,193 @@ func TestEmbeddedHas(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolymorphic(t *testing.T) {
|
||||
t.Run("has one", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
RefId int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has one with only polymorphic type", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has many", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
RefId int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedBelongsTo(t *testing.T) {
|
||||
type Country struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
|
@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
||||
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
|
||||
}
|
||||
|
||||
for _, f := range relation.JoinTable.Fields {
|
||||
checkSchemaField(t, r.JoinTable, &f, nil)
|
||||
for i := range relation.JoinTable.Fields {
|
||||
checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -46,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, user, &f, func(f *schema.Field) {
|
||||
for i := range fields {
|
||||
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
@ -136,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, user, &f, func(f *schema.Field) {
|
||||
for i := range fields {
|
||||
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
|
38
statement.go
38
statement.go
@ -326,7 +326,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
case *DB:
|
||||
v.executeScopes()
|
||||
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
if len(where.Exprs) == 1 {
|
||||
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
||||
@ -334,13 +334,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
}
|
||||
conds = append(conds, clause.And(where.Exprs...))
|
||||
} else {
|
||||
} else if cs.Expression != nil {
|
||||
conds = append(conds, cs.Expression)
|
||||
}
|
||||
if v.Statement == stmt {
|
||||
cs.Expression = nil
|
||||
stmt.Statement.Clauses["WHERE"] = cs
|
||||
}
|
||||
}
|
||||
case map[interface{}]interface{}:
|
||||
for i, j := range v {
|
||||
@ -451,8 +447,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return conds
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -461,7 +458,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
}
|
||||
|
||||
return conds
|
||||
if len(conds) > 0 {
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
@ -665,7 +665,21 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
|
||||
var matchName = func() func(tableColumn string) (table, column string) {
|
||||
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
|
||||
return func(tableColumn string) (table, column string) {
|
||||
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
|
||||
table = matches[1]
|
||||
star := matches[2]
|
||||
columnName := matches[3]
|
||||
if star != "" {
|
||||
return table, star
|
||||
}
|
||||
return table, columnName
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
}()
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
@ -686,13 +700,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = result
|
||||
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
|
||||
if matches[2] == "*" {
|
||||
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
|
||||
if col == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else {
|
||||
results[matches[2]] = result
|
||||
results[col] = result
|
||||
}
|
||||
} else {
|
||||
results[column] = result
|
||||
|
@ -56,9 +56,15 @@ func TestNameMatcher(t *testing.T) {
|
||||
"`name_1`": {"", "name_1"},
|
||||
"`Name_1`": {"", "Name_1"},
|
||||
"`Table`.`nAme`": {"Table", "nAme"},
|
||||
"my_table.*": {"my_table", "*"},
|
||||
"`my_table`.*": {"my_table", "*"},
|
||||
"User__Company.*": {"User__Company", "*"},
|
||||
"`User__Company`.*": {"User__Company", "*"},
|
||||
`"User__Company".*`: {"User__Company", "*"},
|
||||
`"table"."*"`: {"", ""},
|
||||
} {
|
||||
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] {
|
||||
t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
|
||||
if table, column := matchName(k); table != v[0] || column != v[1] {
|
||||
t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -422,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
|
||||
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
users := []User{
|
||||
*GetUser("slice-hasmany-1", Config{Toys: 2}),
|
||||
*GetUser("slice-hasmany-2", Config{Toys: 0}),
|
||||
*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
|
||||
*GetUser("slice-hasmany-3", Config{Toys: 4}),
|
||||
}
|
||||
|
||||
@ -430,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
|
||||
// Count
|
||||
AssertAssociationCount(t, users, "Toys", 6, "")
|
||||
AssertAssociationCount(t, users, "Tools", 2, "")
|
||||
|
||||
// Find
|
||||
var toys []Toy
|
||||
@ -437,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
t.Errorf("toys count should be %v, but got %v", 6, len(toys))
|
||||
}
|
||||
|
||||
// Find Tools (polymorphic with custom type and id)
|
||||
var tools []Tools
|
||||
DB.Model(&users).Association("Tools").Find(&tools)
|
||||
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("tools count should be %v, but got %v", 2, len(tools))
|
||||
}
|
||||
|
||||
// Append
|
||||
DB.Model(&users).Association("Toys").Append(
|
||||
&Toy{Name: "toy-slice-append-1"},
|
||||
|
10
tests/go.mod
10
tests/go.mod
@ -3,7 +3,7 @@ module gorm.io/gorm/tests
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.4.0
|
||||
github.com/google/uuid v1.5.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/stretchr/testify v1.8.4
|
||||
@ -20,15 +20,15 @@ require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.19 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
@ -23,6 +23,7 @@ type Config struct {
|
||||
Languages int
|
||||
Friends int
|
||||
NamedPet bool
|
||||
Tools int
|
||||
}
|
||||
|
||||
func GetUser(name string, config Config) *User {
|
||||
@ -47,6 +48,10 @@ func GetUser(name string, config Config) *User {
|
||||
user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)})
|
||||
}
|
||||
|
||||
for i := 0; i < config.Tools; i++ {
|
||||
user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)})
|
||||
}
|
||||
|
||||
if config.Company {
|
||||
user.Company = Company{Name: "company-" + name}
|
||||
}
|
||||
@ -118,11 +123,13 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
|
||||
if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday",
|
||||
"CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
}
|
||||
|
||||
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID",
|
||||
"ManagerID", "Active")
|
||||
|
||||
t.Run("Account", func(t *testing.T) {
|
||||
AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
|
||||
@ -133,7 +140,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
|
||||
} else {
|
||||
var account Account
|
||||
db(unscoped).First(&account, "user_id = ?", user.ID)
|
||||
AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
|
||||
AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID",
|
||||
"Number")
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -193,8 +201,10 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
|
||||
} else {
|
||||
var manager User
|
||||
db(unscoped).First(&manager, "id = ?", *user.ManagerID)
|
||||
AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
} else if user.ManagerID != nil {
|
||||
t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
|
||||
@ -215,7 +225,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
|
||||
})
|
||||
|
||||
for idx, team := range user.Team {
|
||||
AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
})
|
||||
|
||||
@ -250,7 +261,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
|
||||
})
|
||||
|
||||
for idx, friend := range user.Friends {
|
||||
AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}}
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||
DB.Migrator().DropTable("user_speaks", "user_friends", "ccc")
|
||||
@ -34,7 +34,7 @@ func TestMigrate(t *testing.T) {
|
||||
if tables, err := DB.Migrator().GetTables(); err != nil {
|
||||
t.Fatalf("Failed to get database all tables, but got error %v", err)
|
||||
} else {
|
||||
for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} {
|
||||
for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} {
|
||||
hasTable := false
|
||||
for _, t2 := range tables {
|
||||
if t2 == t1 {
|
||||
@ -93,7 +93,8 @@ func TestAutoMigrateInt8PG(t *testing.T) {
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
|
||||
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
|
||||
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s",
|
||||
sql)
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -432,40 +433,50 @@ func TestTiDBMigrateColumns(t *testing.T) {
|
||||
switch columnType.Name() {
|
||||
case "id":
|
||||
if v, ok := columnType.PrimaryKey(); !ok || !v {
|
||||
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "name":
|
||||
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
|
||||
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
}
|
||||
if length, ok := columnType.Length(); !ok || length != 100 {
|
||||
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
|
||||
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v",
|
||||
columnType.Name(), length, 100, columnType)
|
||||
}
|
||||
case "age":
|
||||
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
|
||||
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !ok || v != "my age" {
|
||||
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "code":
|
||||
if v, ok := columnType.Unique(); !ok || !v {
|
||||
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
|
||||
columnType.Name(), columnType, v)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !ok || v != "my code2" {
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "code2":
|
||||
// Code2 string `gorm:"comment:my code2;default:hello"`
|
||||
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
|
||||
columnType.Name(), columnType, v)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !ok || v != "my code2" {
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -497,7 +508,8 @@ func TestTiDBMigrateColumns(t *testing.T) {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
|
||||
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName",
|
||||
"new_new_name"); err != nil {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
@ -561,36 +573,45 @@ func TestMigrateColumns(t *testing.T) {
|
||||
switch columnType.Name() {
|
||||
case "id":
|
||||
if v, ok := columnType.PrimaryKey(); !ok || !v {
|
||||
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "name":
|
||||
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
|
||||
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
}
|
||||
if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
|
||||
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
|
||||
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v",
|
||||
columnType.Name(), length, 100, columnType)
|
||||
}
|
||||
case "age":
|
||||
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
|
||||
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
|
||||
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "code":
|
||||
if v, ok := columnType.Unique(); !ok || !v {
|
||||
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
|
||||
columnType.Name(), columnType, v)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "code2":
|
||||
if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
|
||||
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "code3":
|
||||
// TODO
|
||||
@ -627,7 +648,8 @@ func TestMigrateColumns(t *testing.T) {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
|
||||
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName",
|
||||
"new_new_name"); err != nil {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
@ -1555,7 +1577,8 @@ func TestMigrateIgnoreRelations(t *testing.T) {
|
||||
func TestMigrateView(t *testing.T) {
|
||||
DB.Save(GetUser("joins-args-db", Config{Pets: 2}))
|
||||
|
||||
if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired {
|
||||
if err := DB.Migrator().CreateView("invalid_users_pets",
|
||||
gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired {
|
||||
t.Fatalf("no view should be created, got %v", err)
|
||||
}
|
||||
|
||||
@ -1624,17 +1647,20 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) {
|
||||
switch columnType.Name() {
|
||||
case "id":
|
||||
if v, ok := columnType.PrimaryKey(); !ok || !v {
|
||||
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
|
||||
columnType)
|
||||
}
|
||||
case "string_bool":
|
||||
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
|
||||
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
}
|
||||
case "smallint_bool":
|
||||
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
|
||||
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1659,7 +1685,8 @@ func TestTableType(t *testing.T) {
|
||||
|
||||
DB.Migrator().DropTable(&City{})
|
||||
|
||||
if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
|
||||
if err := DB.Set("gorm:table_options",
|
||||
fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
|
||||
t.Fatalf("failed to migrate cities tables, got error: %v", err)
|
||||
}
|
||||
|
||||
|
@ -307,6 +307,63 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
|
||||
CheckUserUnscoped(t, *user6, user)
|
||||
}
|
||||
|
||||
func TestNestedPreloadWithNestedJoin(t *testing.T) {
|
||||
type (
|
||||
Preload struct {
|
||||
ID uint
|
||||
Value string
|
||||
NestedID uint
|
||||
}
|
||||
Join struct {
|
||||
ID uint
|
||||
Value string
|
||||
NestedID uint
|
||||
}
|
||||
Nested struct {
|
||||
ID uint
|
||||
Preloads []*Preload
|
||||
Join Join
|
||||
ValueID uint
|
||||
}
|
||||
Value struct {
|
||||
ID uint
|
||||
Name string
|
||||
Nested Nested
|
||||
}
|
||||
)
|
||||
|
||||
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
|
||||
value := Value{
|
||||
Name: "value",
|
||||
Nested: Nested{
|
||||
Preloads: []*Preload{
|
||||
{Value: "p1"}, {Value: "p2"},
|
||||
},
|
||||
Join: Join{Value: "j1"},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&value).Error; err != nil {
|
||||
t.Errorf("failed to create value, got err: %v", err)
|
||||
}
|
||||
|
||||
var find1 Value
|
||||
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, find1, value)
|
||||
|
||||
var find2 Value
|
||||
// Joins will automatically add Nested queries.
|
||||
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, find2, value)
|
||||
}
|
||||
|
||||
func TestEmbedPreload(t *testing.T) {
|
||||
type Country struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
|
@ -1118,12 +1118,12 @@ func TestSearchWithStruct(t *testing.T) {
|
||||
}
|
||||
|
||||
result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
|
@ -84,7 +84,9 @@ func TestComplexScopes(t *testing.T) {
|
||||
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Scopes(
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) },
|
||||
func(d *gorm.DB) *gorm.DB {
|
||||
return d.Where(DB.Or("b = 2").Or("c = 3"))
|
||||
},
|
||||
).Find(&Language{})
|
||||
},
|
||||
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
|
||||
@ -93,7 +95,9 @@ func TestComplexScopes(t *testing.T) {
|
||||
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Where("z = 0").Scopes(
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||
func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) },
|
||||
func(d *gorm.DB) *gorm.DB {
|
||||
return d.Or(DB.Where("b = 2").Or("c = 3"))
|
||||
},
|
||||
).Find(&Language{})
|
||||
},
|
||||
expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
|
||||
@ -104,7 +108,7 @@ func TestComplexScopes(t *testing.T) {
|
||||
func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
|
||||
func(d *gorm.DB) *gorm.DB {
|
||||
return d.
|
||||
Or(d.Scopes(
|
||||
Or(DB.Scopes(
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
|
||||
)).
|
||||
|
@ -388,7 +388,7 @@ func TestToSQL(t *testing.T) {
|
||||
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{})
|
||||
})
|
||||
assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql)
|
||||
assertEqualSQL(t, `SELECT * FROM "users" WHERE ("users"."name" = 'foo' AND "users"."age" = 20) AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql)
|
||||
|
||||
// last and unscoped
|
||||
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
|
@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||
|
||||
func RunMigrations() {
|
||||
var err error
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||
|
||||
|
@ -882,3 +882,52 @@ func TestSaveWithHooks(t *testing.T) {
|
||||
t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// only postgres, sqlserver, sqlite support update from
|
||||
func TestUpdateFrom(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" {
|
||||
return
|
||||
}
|
||||
|
||||
users := []*User{
|
||||
GetUser("update-from-1", Config{Account: true}),
|
||||
GetUser("update-from-2", Config{Account: true}),
|
||||
GetUser("update-from-3", Config{}),
|
||||
}
|
||||
|
||||
if err := DB.Create(&users).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
} else if users[0].ID == 0 {
|
||||
t.Fatalf("user's primary value should not zero, %v", users[0].ID)
|
||||
} else if users[0].UpdatedAt.IsZero() {
|
||||
t.Fatalf("user's updated at should not zero, %v", users[0].UpdatedAt)
|
||||
}
|
||||
|
||||
if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number = ? AND accounts.deleted_at IS NULL", users[0].Account.Number).Update("name", "franco").RowsAffected; rowsAffected != 1 {
|
||||
t.Errorf("should only update one record, but got %v", rowsAffected)
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err != nil {
|
||||
t.Errorf("errors happened when query before user: %v", err)
|
||||
} else if result.UpdatedAt.UnixNano() == users[0].UpdatedAt.UnixNano() {
|
||||
t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, users[0].UpdatedAt)
|
||||
} else if result.Name != "franco" {
|
||||
t.Errorf("user's name should be updated")
|
||||
}
|
||||
|
||||
if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number IN ? AND accounts.deleted_at IS NULL", []string{users[0].Account.Number, users[1].Account.Number}).Update("name", gorm.Expr("accounts.number")).RowsAffected; rowsAffected != 2 {
|
||||
t.Errorf("should update two records, but got %v", rowsAffected)
|
||||
}
|
||||
|
||||
var results []User
|
||||
if err := DB.Preload("Account").Find(&results, []uint{users[0].ID, users[1].ID}).Error; err != nil {
|
||||
t.Errorf("Not error should happen when finding users, but got %v", err)
|
||||
}
|
||||
|
||||
for _, user := range results {
|
||||
if user.Name != user.Account.Number {
|
||||
t.Errorf("user's name should be equal to the account's number %v, but got %v", user.Account.Number, user.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,8 @@ type User struct {
|
||||
Account Account
|
||||
Pets []*Pet
|
||||
NamedPet *Pet
|
||||
Toys []Toy `gorm:"polymorphic:Owner"`
|
||||
Toys []Toy `gorm:"polymorphic:Owner"`
|
||||
Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"`
|
||||
CompanyID *int
|
||||
Company Company
|
||||
ManagerID *uint
|
||||
@ -51,6 +52,13 @@ type Toy struct {
|
||||
OwnerType string
|
||||
}
|
||||
|
||||
type Tools struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
CustomID string
|
||||
Type string
|
||||
}
|
||||
|
||||
type Company struct {
|
||||
ID int
|
||||
Name string
|
||||
|
@ -35,7 +35,8 @@ func FileWithLineNum() string {
|
||||
// the second caller usually from gorm internal, so set i start from 2
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) {
|
||||
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) &&
|
||||
!strings.HasSuffix(file, ".gen.go") {
|
||||
return file + ":" + strconv.FormatInt(int64(line), 10)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user