diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index aa1812d4..bc4487ae 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index c3c92beb..f9f51aa0 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index af8d3636..a9aff43a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b97da3f4..367f4ccd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -86,7 +86,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -128,7 +128,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/.golangci.yml b/.golangci.yml index 16903ed6..b88bf672 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,3 +9,12 @@ linters: - prealloc - unconvert - unparam + - goimports + - whitespace + +linters-settings: + whitespace: + multi-func: true + goimports: + local-prefixes: gorm.io/gorm + diff --git a/README.md b/README.md index 312a3a59..68fa6603 100644 --- a/README.md +++ b/README.md @@ -30,12 +30,18 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) -* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) +* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) ## Contributing [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) +## Contributors + +Thank you for contributing to the GORM framework! + +[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors) + ## License © Jinzhu, 2013~time.Now diff --git a/callbacks/associations.go b/callbacks/associations.go index 00e00fcc..9d7c1412 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + if isPtr { elems = reflect.Append(elems, elem) } else { @@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + distinctElems = reflect.Append(distinctElems, elem) } diff --git a/callbacks/query.go b/callbacks/query.go index a0e75a7e..97fe8a49 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -117,12 +117,20 @@ func BuildQuerySQL(db *gorm.DB) { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } } exprs := make([]clause.Expression, len(relation.References)) diff --git a/callbacks/update.go b/callbacks/update.go index 48c61bf4..b596df9a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -70,10 +70,12 @@ func Update(config *Config) func(db *gorm.DB) { if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else if _, ok := db.Statement.Clauses["SET"]; !ok { - return + if _, ok := db.Statement.Clauses["SET"]; !ok { + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } } db.Statement.Build(db.Statement.BuildClauses...) diff --git a/chainable_api.go b/chainable_api.go index 65c56362..8a92a9e3 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -10,10 +10,11 @@ import ( ) // Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello` +// db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Model = value @@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) { } // Clauses Add clauses +// +// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more +// advanced techniques like specifying lock strength and optimizer hints. See the +// [docs] for more depth. +// +// // add a simple limit clause +// db.Clauses(clause.Limit{Limit: 1}).Find(&User{}) +// // tell the optimizer to use the `idx_user_name` index +// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{}) +// // specify the lock strength to UPDATE +// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users) +// +// [docs]: https://gorm.io/docs/sql_builder.html#Clauses func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { tx = db.getInstance() var whereConds []interface{} @@ -44,6 +58,9 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) // Table specify the table you would like to run db operations +// +// // Get a user +// db.Table("users").take(&result) func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { @@ -65,6 +82,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } // Distinct specify distinct fields that you want querying +// +// // Select distinct names of users +// db.Distinct("name").Find(&results) +// // Select distinct name/age pairs from users +// db.Distinct("name", "age").Find(&results) func (db *DB) Distinct(args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Distinct = true @@ -75,6 +97,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { } // Select specify fields that you want when querying, creating, updating +// +// Use Select when you only want a subset of the fields. By default, GORM will select all fields. +// Select accepts both string arguments and arrays. +// +// // Select name and age of user using multiple arguments +// db.Select("name", "age").Find(&users) +// // Select name and age of user using an array +// db.Select([]string{"name", "age"}).Find(&users) func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() @@ -152,6 +182,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) { } // Where add conditions +// +// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. +// +// // Find the first user with name jinzhu +// db.Where("name = ?", "jinzhu").First(&user) +// // Find the first user with name jinzhu and age 20 +// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) +// // Find the first user with name jinzhu and age not equal to 20 +// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user) +// +// [docs]: https://gorm.io/docs/query.html#Conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -161,6 +202,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { } // Not add NOT conditions +// +// Not works similarly to where, and has the same syntax. +// +// // Find the first user with name not equal to jinzhu +// db.Not("name = ?", "jinzhu").First(&user) func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -170,6 +216,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { } // Or add OR conditions +// +// Or is used to chain together queries with an OR. +// +// // Find the first user with name equal to jinzhu or john +// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user) func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -179,9 +230,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { } // Joins specify Joins conditions -// db.Joins("Account").Find(&user) -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +// +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { return joins(db, clause.LeftJoin, query, args...) } @@ -197,10 +249,12 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) if len(args) == 1 { if db, ok := args[0].(*DB); ok { + j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits} if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where, JoinType: joinType}) - return + j.On = &where } + tx.Statement.Joins = append(tx.Statement.Joins, j) + return } } @@ -209,6 +263,9 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) } // Group specify the group method on the find +// +// // Select the sum age of users with given names +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() @@ -220,6 +277,9 @@ func (db *DB) Group(name string) (tx *DB) { } // Having specify HAVING conditions for GROUP BY +// +// // Select the sum age of users with name jinzhu +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result) func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ @@ -228,9 +288,10 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { return } -// Order specify order when retrieve records from database -// db.Order("name DESC") -// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +// Order specify order when retrieving records from database +// +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() @@ -252,13 +313,27 @@ func (db *DB) Order(value interface{}) (tx *DB) { } // Limit specify the number of records to be retrieved +// +// Limit conditions can be cancelled by using `Limit(-1)`. +// +// // retrieve 3 users +// db.Limit(3).Find(&users) +// // retrieve 3 users into users1, and all users into users2 +// db.Limit(3).Find(&users1).Limit(-1).Find(&users2) func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Limit{Limit: limit}) + tx.Statement.AddClause(clause.Limit{Limit: &limit}) return } // Offset specify the number of records to skip before starting to return the records +// +// Offset conditions can be cancelled by using `Offset(-1)`. +// +// // select the third user +// db.Offset(2).First(&user) +// // select the first user by cancelling an earlier chained offset +// db.Offset(5).Offset(-1).First(&user) func (db *DB) Offset(offset int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Offset: offset}) @@ -266,17 +341,18 @@ func (db *DB) Offset(offset int) (tx *DB) { } // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } // -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } // -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { tx = db.getInstance() tx.Statement.scopes = append(tx.Statement.scopes, funcs...) @@ -284,7 +360,9 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } // Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +// +// // get all users, and preload all non-cancelled orders +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Preloads == nil { @@ -294,12 +372,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { return } +// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Attrs only adds attributes if the record is not found. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign an email if the record is not found, otherwise ignore provided email +// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.attrs = attrs return } +// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that +// records will be updated even if they are found. +// +// // assign an email regardless of if the record is not found +// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.assigns = attrs diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index e08677ac..34d5df41 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + limit10 := 10 for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ @@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) { clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), }}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, - clause.Limit{Limit: 10, Offset: 20}, + clause.Limit{Limit: &limit10, Offset: 20}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, } diff --git a/clause/joins.go b/clause/joins.go index f3e373f2..879892be 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -9,7 +9,7 @@ const ( RightJoin JoinType = "RIGHT" ) -// Join join clause for from +// Join clause for from type Join struct { Type JoinType Table Table diff --git a/clause/joins_test.go b/clause/joins_test.go new file mode 100644 index 00000000..f1f20ec3 --- /dev/null +++ b/clause/joins_test.go @@ -0,0 +1,101 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestJoin(t *testing.T) { + results := []struct { + name string + join clause.Join + sql string + }{ + { + name: "LEFT JOIN", + join: clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "RIGHT JOIN", + join: clause.Join{ + Type: clause.RightJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "INNER JOIN", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "CROSS JOIN", + join: clause.Join{ + Type: clause.CrossJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "USING", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + { + name: "Expression", + join: clause.Join{ + // Invalid + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + // Valid + Expression: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + } + for _, result := range results { + t.Run(result.name, func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + result.join.Build(stmt) + if result.sql != stmt.SQL.String() { + t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) + } + }) + } +} diff --git a/clause/limit.go b/clause/limit.go index 184f6025..3ede7385 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -4,7 +4,7 @@ import "strconv" // Limit limit clause type Limit struct { - Limit int + Limit *int Offset int } @@ -15,12 +15,12 @@ func (limit Limit) Name() string { // Build build where clause func (limit Limit) Build(builder Builder) { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(limit.Limit)) + builder.WriteString(strconv.Itoa(*limit.Limit)) } if limit.Offset > 0 { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") @@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit != 0 { + if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) { limit.Limit = v.Limit } diff --git a/clause/limit_test.go b/clause/limit_test.go index c26294aa..79065ab6 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -8,6 +8,10 @@ import ( ) func TestLimit(t *testing.T) { + limit0 := 0 + limit10 := 10 + limit50 := 50 + limitNeg10 := -10 results := []struct { Clauses []clause.Interface Result string @@ -15,11 +19,15 @@ func TestLimit(t *testing.T) { }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ - Limit: 10, + Limit: &limit10, Offset: 20, }}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, + "SELECT * FROM `users` LIMIT 0", nil, + }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, "SELECT * FROM `users` OFFSET 20", nil, @@ -29,23 +37,23 @@ func TestLimit(t *testing.T) { "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, "SELECT * FROM `users` LIMIT 10", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, }, } diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 309c5fcd..032bf4a1 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -16,27 +16,27 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { - if len(onConflict.Columns) > 0 { - builder.WriteByte('(') - for idx, column := range onConflict.Columns { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(column) - } - builder.WriteString(`) `) - } - - if len(onConflict.TargetWhere.Exprs) > 0 { - builder.WriteString(" WHERE ") - onConflict.TargetWhere.Build(builder) - builder.WriteByte(' ') - } - if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } } if onConflict.DoNothing { diff --git a/finisher_api.go b/finisher_api.go index af9afb63..39d9fca3 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -13,7 +13,7 @@ import ( "gorm.io/gorm/utils" ) -// Create insert the value into database +// Create inserts value, returning the inserted data's primary key in value's id func (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) @@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { return tx.callbacks.Create().Execute(tx) } -// CreateInBatches insert the value in batches into database +// CreateInBatches inserts value in batches of batchSize func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return } -// Save update value in database, if the value doesn't have primary key, will insert it +// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { return } -// First find first record that match given conditions, order by primary key +// First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Take return a record that match given conditions, the order will depend on the database implementation +// Take finds the first record returned by the database in no specified order, matching given conditions conds func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { @@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Last find last record that match given conditions, order by primary key +// Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Find find records that match given conditions +// Find finds all records matching given conditions conds func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// FindInBatches find records in batches +// FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { var ( tx = db.Order(clause.OrderByColumn{ @@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat var totalSize int if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if limit, ok := c.Expression.(clause.Limit); ok { - totalSize = limit.Limit + if limit.Limit != nil { + totalSize = *limit.Limit + } if totalSize > 0 && batchSize > totalSize { batchSize = totalSize @@ -229,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -286,7 +292,18 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { } } -// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) +// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. +// Each conds must be a struct or map. +// +// FirstOrInit never modifies the database. It is often used with Assign and Attrs. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -312,62 +329,78 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } -// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) +// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. +// Each conds must be a struct or map. +// +// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists. +// +// // assign an email if the record is not found +// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +// +// // assign email regardless of if record is found +// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// // result.RowsAffected -> 1 func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if result := queryTx.Find(dest, conds...); result.Error == nil { - if result.RowsAffected == 0 { - if c, ok := result.Statement.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok { - result.assignInterfacesToValue(where.Exprs) - } - } - // initialize with attrs, conds - if len(db.Statement.attrs) > 0 { - result.assignInterfacesToValue(db.Statement.attrs...) - } - - // initialize with attrs, conds - if len(db.Statement.assigns) > 0 { - result.assignInterfacesToValue(db.Statement.assigns...) - } - - return tx.Create(dest) - } else if len(db.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) - assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - assigns[column] = eq.Value - case clause.Column: - assigns[column.Name] = eq.Value - default: - } - } - } - - return tx.Model(dest).Updates(assigns) - } - } else { + result := queryTx.Find(dest, conds...) + if result.Error != nil { tx.Error = result.Error + return tx } + + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + result.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) + } + + // initialize with attrs, conds + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + } + } + } + + return tx.Model(dest).Updates(assigns) + } + return tx } -// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} return tx.callbacks.Update().Execute(tx) } -// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values @@ -388,7 +421,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return tx.callbacks.Update().Execute(tx) } -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If +// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// time if null. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -455,7 +490,7 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx = tx.callbacks.Query().Execute(tx) - if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { + if tx.RowsAffected != 1 { *count = tx.RowsAffected } @@ -482,7 +517,7 @@ func (db *DB) Rows() (*sql.Rows, error) { return rows, tx.Error } -// Scan scan value to a struct +// Scan scans selected value to the struct dest func (db *DB) Scan(dest interface{}) (tx *DB) { config := *db.Config currentLogger, newLogger := config.Logger, logger.Recorder.New() @@ -507,9 +542,10 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Model(&users).Pluck("age", &ages) +// Pluck queries a single column from a model, returning in the slice dest. E.g.: +// +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { @@ -550,7 +586,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { return tx.Error } -// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is +// returned to the connection pool. func (db *DB) Connection(fc func(tx *DB) error) (err error) { if db.Error != nil { return db.Error @@ -572,7 +609,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an +// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs +// they are rolled back. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true @@ -615,7 +654,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er return } -// Begin begins a transaction +// Begin begins a transaction with any transaction options opts func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement @@ -644,7 +683,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { return tx } -// Commit commit a transaction +// Commit commits the changes in a transaction func (db *DB) Commit() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) @@ -654,7 +693,7 @@ func (db *DB) Commit() *DB { return db } -// Rollback rollback a transaction +// Rollback rollbacks the changes in a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if !reflect.ValueOf(committer).IsNil() { @@ -684,7 +723,7 @@ func (db *DB) RollbackTo(name string) *DB { return db } -// Exec execute raw sql +// Exec executes raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} diff --git a/go.mod b/go.mod index 57362745..03f84379 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm -go 1.14 +go 1.16 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/gorm.go b/gorm.go index c852e60c..89488b75 100644 --- a/gorm.go +++ b/gorm.go @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string]Stmt{}, + Stmts: make(map[string]*Stmt), Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } @@ -248,10 +248,18 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt := v.(*PreparedStmtDB) - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } } txConfig.ConnPool = tx.Statement.ConnPool txConfig.PrepareStmt = true @@ -413,7 +421,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac relation, ok := modelSchema.Relationships.Relations[field] isRelation := ok && relation.JoinTable != nil if !isRelation { - return fmt.Errorf("failed to found relation: %s", field) + return fmt.Errorf("failed to find relation: %s", field) } for _, ref := range relation.References { diff --git a/logger/logger.go b/logger/logger.go index 2ffd28d5..ce088561 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "log" "os" "time" @@ -68,8 +68,8 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to ioutil.Discard - Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Discard Discard logger will print any log to io.Discard + Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, diff --git a/migrator.go b/migrator.go index 34e888f2..882fc4cc 100644 --- a/migrator.go +++ b/migrator.go @@ -68,6 +68,7 @@ type Migrator interface { // Database CurrentDatabase() string FullDataTypeOf(*schema.Field) clause.Expr + GetTypeAliases(databaseTypeName string) []string // Tables CreateTable(dst ...interface{}) error diff --git a/migrator/migrator.go b/migrator/migrator.go index 87ac7745..eafe7bb2 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -8,14 +8,16 @@ import ( "reflect" "regexp" "strings" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) var ( - regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) + regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) ) // Migrator m struct @@ -30,6 +32,16 @@ type Config struct { gorm.Dialector } +type printSQLLogger struct { + logger.Interface +} + +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + fmt.Println(sql + ";") + l.Interface.Trace(ctx, begin, fc, err) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string @@ -92,14 +104,19 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) - if !tx.Migrator().HasTable(value) { - if err := tx.Migrator().CreateTable(value); err != nil { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + if !queryTx.Migrator().HasTable(value) { + if err := execTx.Migrator().CreateTable(value); err != nil { return err } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, err := m.DB.Migrator().ColumnTypes(value) + columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err } @@ -117,10 +134,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := tx.Migrator().AddColumn(value, dbName); err != nil { + if err := execTx.Migrator().AddColumn(value, dbName); err != nil { return err } - } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + } else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { // found, smart migrate return err } @@ -129,25 +146,25 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { if constraint := rel.ParseConstraint(); constraint != nil && - constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err - } - } - } - - for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } } + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + for _, idx := range stmt.Schema.ParseIndexes() { - if !tx.Migrator().HasIndex(value, idx.Name) { - if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } } @@ -406,32 +423,49 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - alterColumn := false + var ( + alterColumn, isSameType bool + ) - // check type - if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { - alterColumn = true - } + if !field.PrimaryKey { + // check type + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } - // check size - if length, ok := columnType.Length(); length != int64(field.Size) { - if length > 0 && field.Size > 0 { - alterColumn = true - } else { - // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here - matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if !field.PrimaryKey && - (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + if !isSameType { alterColumn = true } } } - // check precision - if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { - alterColumn = true + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } } } @@ -453,17 +487,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check default value if !field.PrimaryKey { + currentDefaultNotNull := field.HasDefaultValue && !strings.EqualFold(field.DefaultValue, "NULL") dv, dvNotNull := columnType.DefaultValue() - if dvNotNull && field.DefaultValueInterface == nil { + if dvNotNull && !currentDefaultNotNull { // defalut value -> null alterColumn = true - } else if !dvNotNull && field.DefaultValueInterface != nil { + } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if dv != field.DefaultValue { + } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || + (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { // default value not equal // not both null - if !(field.DefaultValueInterface == nil && !dvNotNull) { + if currentDefaultNotNull || dvNotNull { alterColumn = true } } @@ -478,7 +514,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.Name) + return m.DB.Migrator().AlterColumn(value, field.DBName) } return nil @@ -863,3 +899,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { return nil, errors.New("not support") } + +// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} diff --git a/prepare_stmt.go b/prepare_stmt.go index b062b0d6..e09fe814 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,10 +9,12 @@ import ( type Stmt struct { *sql.Stmt Transaction bool + prepared chan struct{} + prepareErr error } type PreparedStmtDB struct { - Stmts map[string]Stmt + Stmts map[string]*Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool @@ -42,31 +44,72 @@ func (db *PreparedStmtDB) Close() { } } +func (db *PreparedStmtDB) Reset() { + db.Mux.Lock() + defer db.Mux.Unlock() + + for _, stmt := range db.Stmts { + go stmt.Close() + } + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = make(map[string]*Stmt) +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() - return stmt, nil + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } db.Mux.RUnlock() db.Mux.Lock() - defer db.Mux.Unlock() - // double check if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - return stmt, nil - } else if ok { - go stmt.Close() + db.Mux.Unlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } + // cache preparing stmt first + cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} + db.Stmts[query] = &cacheStmt + db.Mux.Unlock() + + // prepare completed + defer close(cacheStmt.prepared) + + // Reason why cannot lock conn.PrepareContext + // suppose the maxopen is 1, g1 is creating record and g2 is querying record. + // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. + // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) - if err == nil { - db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} - db.PreparedSQL = append(db.PreparedSQL, query) + if err != nil { + cacheStmt.prepareErr = err + db.Mux.Lock() + delete(db.Stmts, query) + db.Mux.Unlock() + return Stmt{}, err } - return db.Stmts[query], err + db.Mux.Lock() + cacheStmt.Stmt = stmt + db.PreparedSQL = append(db.PreparedSQL, query) + db.Mux.Unlock() + + return cacheStmt, nil } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { diff --git a/scan.go b/scan.go index 6250fb57..12a77862 100644 --- a/scan.go +++ b/scan.go @@ -65,31 +65,32 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - - joinedSchemaMap := make(map[*schema.Field]interface{}, 0) + joinedSchemaMap := make(map[*schema.Field]interface{}) for idx, field := range fields { - if field != nil { - if len(joinFields) == 0 || joinFields[idx][0] == nil { - db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - joinSchema := joinFields[idx][0] - relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr { - if _, ok := joinedSchemaMap[joinSchema]; !ok { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - joinedSchemaMap[joinSchema] = nil - } - } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) - } - - // release data to pool - field.NewValuePool.Put(values[idx]) + if field == nil { + continue } + + if len(joinFields) == 0 || joinFields[idx][0] == nil { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { + joinSchema := joinFields[idx][0] + relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr { + if _, ok := joinedSchemaMap[joinSchema]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedSchemaMap[joinSchema] = nil + } + } + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) + } + + // release data to pool + field.NewValuePool.Put(values[idx]) } } @@ -161,11 +162,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - selectedColumnsMap = make(map[string]int, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -198,26 +198,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) { // Not Pluck if sch != nil { - schFieldsCount := len(sch.Fields) + matchedFieldCount := make(map[string]int, len(columns)) for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - if curIndex, ok := selectedColumnsMap[column]; ok { - fields[idx] = field // handle duplicate fields - offset := curIndex + 1 - // handle sch inconsistent with database - // like Raw(`...`).Scan - if schFieldsCount > offset { - for fieldIndex, selectField := range sch.Fields[offset:] { - if selectField.DBName == column && selectField.Readable { - selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = field + if count, ok := matchedFieldCount[column]; ok { + // handle duplicate fields + for _, selectField := range sch.Fields { + if selectField.DBName == column && selectField.Readable { + if count == 0 { + matchedFieldCount[column]++ fields[idx] = selectField break } + count-- } } } else { - fields[idx] = field - selectedColumnsMap[column] = idx + matchedFieldCount[column] = 1 } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { @@ -241,12 +239,20 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - var elem reflect.Value - recyclableStruct := reflect.New(reflectValueType) + var ( + elem reflect.Value + isArrayKind = reflectValue.Kind() == reflect.Array + ) if !update || reflectValue.Len() == 0 { update = false - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else if !isArrayKind { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } for initialized || rows.Next() { @@ -267,20 +273,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } } } else { - if isPtr && db.RowsAffected > 0 { - elem = reflect.New(reflectValueType) - } else { - elem = recyclableStruct - } + elem = reflect.New(reflectValueType) } db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { + if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) + reflectValue = reflect.Append(reflectValue, elem) } } } diff --git a/schema/relationship.go b/schema/relationship.go index 0aa33e51..9436f283 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} - ownFieldsMap = map[string]bool{} // fix self join many2many + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) @@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinFieldName = strings.Title(joinForeignKeys[idx]) } - ownFieldsMap[joinFieldName] = true + ownFieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, @@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, relField := range refForeignFields { joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name - if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) - } if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -254,14 +252,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - fieldsMap[joinFieldName] = relField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: joinFieldName, - PkgPath: relField.StructField.PkgPath, - Type: relField.StructField.Type, - Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), - "column", "autoincrement", "index", "unique", "uniqueindex"), - }) + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } } joinTableFields = append(joinTableFields, reflect.StructField{ @@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPrimaryField { + if of, ok := ownFieldsMap[f.Name]; ok { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: of, ForeignKey: f, }) - } else { + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] if joinRefRel.Field == nil { joinRefRel.Field = relation.Field } joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, ForeignKey: f, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPrimaryField, - }) } } } @@ -391,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema case guessHas: case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { - if f := foreignSchema.LookUpField(foreignKey); f != nil { - foreignFields = append(foreignFields, f) - } else { + f := foreignSchema.LookUpField(foreignKey) + if f == nil { reguessOrErr() return } + foreignFields = append(foreignFields, f) } } else { - var primaryFields []*Field var primarySchemaName = primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name @@ -454,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } - if len(foreignFields) == 0 { + switch { + case len(foreignFields) == 0: reguessOrErr() return - } else if len(relation.primaryKeys) > 0 { + case len(relation.primaryKeys) > 0: for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { @@ -471,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu return } } - } else if len(primaryFields) == 0 { + case len(primaryFields) == 0: if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 6fffbfcb..85c45589 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -10,7 +10,7 @@ import ( func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { - t.Errorf("Failed to parse schema") + t.Errorf("Failed to parse schema, got error %v", err) } else { for _, rel := range relations { checkSchemaRelation(t, s, rel) @@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) { }) } +func TestMany2ManySharedForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + Kind string + ProfileRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"` + Kind string + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"Kind", "User", "Kind", "user_profiles", "", true}, + {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false}, + {"Kind", "Profile", "Kind", "user_profiles", "", false}, + }, + }) +} + func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/schema/schema.go b/schema/schema.go index 3791237d..21e71c21 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -71,6 +71,10 @@ type Tabler interface { TableName() string } +type TablerWithNamer interface { + TableName(Namer) string +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) + } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } @@ -223,7 +230,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } for _, field := range schema.Fields { - if field.HasDefaultValue && field.DefaultValueInterface == nil { + if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } } @@ -239,6 +246,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.HasDefaultValue = true field.AutoIncrement = true } + case String: + if _, ok := field.TagSettings["PRIMARYKEY"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + } } } diff --git a/schema/serializer.go b/schema/serializer.go index 00a4f85f..9a6aa4fc 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, // Value implements serializer interface func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) + if string(result) == "null" { + if field.TagSettings["NOT NULL"] != "" { + return "", nil + } + return nil, err + } return string(result), err } diff --git a/statement.go b/statement.go index aa3937ae..bc12eb76 100644 --- a/statement.go +++ b/statement.go @@ -49,10 +49,12 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where - JoinType clause.JoinType + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + JoinType clause.JoinType } // StatementModifier statement modifier interface @@ -180,6 +182,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } + case clause.Interface: + c := clause.Clause{Name: v.Name()} + v.MergeClause(&c) + c.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: @@ -541,8 +547,9 @@ func (stmt *Statement) clone() *Statement { } // SetColumn set column's value -// stmt.SetColumn("Name", "jinzhu") // Hooks Method -// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +// +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value @@ -651,7 +658,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:[\W]?([A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/statement_test.go b/statement_test.go index a537c7be..761daf37 100644 --- a/statement_test.go +++ b/statement_test.go @@ -37,18 +37,18 @@ func TestWhereCloneCorruption(t *testing.T) { func TestNameMatcher(t *testing.T) { for k, v := range map[string][]string{ - "table.name": []string{"table", "name"}, - "`table`.`name`": []string{"table", "name"}, - "'table'.'name'": []string{"table", "name"}, - "'table'.name": []string{"table", "name"}, - "table1.name_23": []string{"table1", "name_23"}, - "`table_1`.`name23`": []string{"table_1", "name23"}, - "'table23'.'name_1'": []string{"table23", "name_1"}, - "'table23'.name1": []string{"table23", "name1"}, - "'name1'": []string{"", "name1"}, - "`name_1`": []string{"", "name_1"}, - "`Name_1`": []string{"", "Name_1"}, - "`Table`.`nAme`": []string{"Table", "nAme"}, + "table.name": {"table", "name"}, + "`table`.`name`": {"table", "name"}, + "'table'.'name'": {"table", "name"}, + "'table'.name": {"table", "name"}, + "table1.name_23": {"table1", "name_23"}, + "`table_1`.`name23`": {"table_1", "name23"}, + "'table23'.'name_1'": {"table23", "name_1"}, + "'table23'.name1": {"table23", "name1"}, + "'name1'": {"", "name1"}, + "`name_1`": {"", "name_1"}, + "`Name_1`": {"", "Name_1"}, + "`Table`.`nAme`": {"Table", "nAme"}, } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index f74799ce..a1f014d9 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -224,3 +225,28 @@ func TestBelongsToAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } + +func TestBelongsToDefaultValue(t *testing.T) { + type Org struct { + ID string + } + type BelongsToUser struct { + OrgID string + Org Org `gorm:"default:NULL"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Config.DisableForeignKeyConstraintWhenMigrating = true + AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false) + + tx.Migrator().DropTable(&BelongsToUser{}, &Org{}) + tx.AutoMigrate(&BelongsToUser{}, &Org{}) + + user := &BelongsToUser{ + Org: Org{ + ID: "BelongsToUser_Org_1", + }, + } + err := DB.Create(&user).Error + AssertEqual(t, err, nil) +} diff --git a/tests/associations_test.go b/tests/associations_test.go index 42b32afc..4c9076da 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) { AssertEqual(t, len(orgs), 0) } } + +type AssociationEmptyUser struct { + ID uint + Name string + Pets []AssociationEmptyPet +} + +type AssociationEmptyPet struct { + AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"` + Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"` +} + +func TestAssociationEmptyPrimaryKey(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + + id := uint(100) + user := AssociationEmptyUser{ + ID: id, + Name: "jinzhu", + Pets: []AssociationEmptyPet{ + {AssociationEmptyUserID: &id, Name: "bar"}, + {AssociationEmptyUserID: &id, Name: "foo"}, + }, + } + + err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error + if err != nil { + t.Fatalf("Failed to create, got error: %v", err) + } + + var result AssociationEmptyUser + err = DB.Preload("Pets").First(&result, &id).Error + if err != nil { + t.Fatalf("Failed to find, got error: %v", err) + } + + AssertEqual(t, result, user) +} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 2bf9496b..4479da4c 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) { for idx, data := range datas { db, err := gorm.Open(nil, nil) + if err != nil { + t.Fatal(err) + } callbacks := db.Callback() for _, c := range data.callbacks { diff --git a/tests/connpool_test.go b/tests/connpool_test.go index fbae2294..42e029bc 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } diff --git a/tests/count_test.go b/tests/count_test.go index b71e3de5..2199dc6d 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -141,7 +141,7 @@ func TestCount(t *testing.T) { } DB.Create(sameUsers) - if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != int64(len(sameUsers)) { t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index e309d06c..ae69baca 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -36,7 +36,7 @@ func TestEmbeddedStruct(t *testing.T) { type EngadgetPost struct { BasePost BasePost `gorm:"Embedded"` - Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct + Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct ImageUrl string } @@ -74,13 +74,27 @@ func TestEmbeddedStruct(t *testing.T) { t.Errorf("embedded struct's value should be scanned correctly") } - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}}) var egNews EngadgetPost if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { t.Errorf("no error should happen when query with embedded struct, but got %v", err) } else if egNews.BasePost.Title != "engadget_news" { t.Errorf("embedded struct's value should be scanned correctly") } + + var egPosts []EngadgetPost + if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil { + t.Fatalf("no error should happen when query with embedded struct, but got %v", err) + } + expectAuthors := []string{"Edward", "George"} + for i, post := range egPosts { + t.Log(i, post.Author) + if want := expectAuthors[i]; post.Author.Name != want { + t.Errorf("expected author %s got %s", want, post.Author.Name) + } + } + } func TestEmbeddedPointerTypeStruct(t *testing.T) { diff --git a/tests/go.mod b/tests/go.mod index eb8f336d..23fc2cad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,19 +1,18 @@ module gorm.io/gorm/tests -go 1.14 +go 1.16 require ( - github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.6 - github.com/mattn/go-sqlite3 v1.14.14 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - gorm.io/driver/mysql v1.3.5 - gorm.io/driver/postgres v1.3.8 - gorm.io/driver/sqlite v1.3.6 - gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.8 + github.com/lib/pq v1.10.7 + github.com/mattn/go-sqlite3 v1.14.16 // indirect + golang.org/x/crypto v0.3.0 // indirect + gorm.io/driver/mysql v1.4.4 + gorm.io/driver/postgres v1.4.5 + gorm.io/driver/sqlite v1.4.3 + gorm.io/driver/sqlserver v1.4.1 + gorm.io/gorm v1.24.2 ) replace gorm.io/gorm => ../ diff --git a/tests/joins_test.go b/tests/joins_test.go index 86ab6bc0..057ad333 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -251,3 +251,77 @@ func TestInnerJoins(t *testing.T) { AssertEqual(t, err, nil) CheckUser(t, user3, user) } + +func TestJoinWithSameColumnName(t *testing.T) { + user := GetUser("TestJoinWithSameColumnName", Config{ + Languages: 1, + Pets: 1, + }) + DB.Create(user) + type UserSpeak struct { + UserID uint + LanguageCode string + } + type Result struct { + User + UserSpeak + Language + Pet + } + + results := make([]Result, 0, 1) + DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id"). + Joins("JOIN languages ON languages.code = user_speaks.language_code"). + Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results) + + if len(results) == 0 { + t.Fatalf("no record find") + } else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID { + t.Fatalf("wrong user id in pet") + } else if results[0].Pet.Name != user.Pets[0].Name { + t.Fatalf("wrong pet name") + } +} + +func TestJoinArgsWithDB(t *testing.T) { + user := *GetUser("joins-args-db", Config{Pets: 2}) + DB.Save(&user) + + // test where + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"}) + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2") + + // test where and omit + onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name") + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID) + AssertEqual(t, user2.NamedPet.Name, "") + + // test where and select + onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name") + var user3 User + if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user3.NamedPet.ID, 0) + AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2") + + // test select + onQuery4 := DB.Select("ID") + var user4 User + if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + if user4.NamedPet.ID == 0 { + t.Fatal("Pet ID can not be empty") + } + AssertEqual(t, user4.NamedPet.Name, "") +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0b5bc5eb..5f7e0749 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,15 +1,19 @@ package tests_test import ( + "context" "fmt" "math/rand" + "os" "reflect" "strings" "testing" "time" "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -72,6 +76,44 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } + +} + +func TestAutoMigrateInt8PG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Smallint int8 + + type MigrateInt struct { + Int8 Smallint + } + + tracer := Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + } + }, + } + + DB.Migrator().DropTable(&MigrateInt{}) + + // The first AutoMigrate to make table with field with correct type + if err := DB.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + + // make new session to set custom logger tracer + session := DB.Session(&gorm.Session{Logger: tracer}) + + // The second AutoMigrate to catch an error + if err := session.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } } func TestAutoMigrateSelfReferential(t *testing.T) { @@ -400,7 +442,7 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) @@ -715,6 +757,32 @@ func TestPrimarykeyID(t *testing.T) { } } +func TestCurrentTimestamp(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + type CurrentTimestampTest struct { + ID string `gorm:"primary_key"` + TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"` + } + var err error + err = DB.Migrator().DropTable(&CurrentTimestampTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) +} + func TestUniqueColumn(t *testing.T) { if DB.Dialector.Name() != "mysql" { return @@ -850,7 +918,7 @@ func findColumnType(dest interface{}, columnName string) ( return } -func TestInvalidCachedPlan(t *testing.T) { +func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { if DB.Dialector.Name() != "postgres" { return } @@ -885,6 +953,101 @@ func TestInvalidCachedPlan(t *testing.T) { } } +func TestInvalidCachedPlanPrepareStmt(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) + if err != nil { + t.Errorf("Open err:%v", err) + } + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + type Object1 struct { + ID uint + } + type Object2 struct { + ID uint + Field1 int `gorm:"type:int8"` + } + type Object3 struct { + ID uint + Field1 int `gorm:"type:int4"` + } + type Object4 struct { + ID uint + Field2 int + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = db.Table("objects").Create(&Object1{}).Error + if err != nil { + t.Errorf("create err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object2{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AlterColumn + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object3{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object4{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } +} + func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { type DiffType struct { ID uint @@ -959,3 +1122,41 @@ func TestMigrateArrayTypeModel(t *testing.T) { AssertEqual(t, nil, err) AssertEqual(t, "integer[]", ct.DatabaseTypeName()) } + +func TestMigrateSameEmbeddedFieldName(t *testing.T) { + type UserStat struct { + GroundDestroyCount int + } + + type GameUser struct { + gorm.Model + StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"` + } + + type UserStat1 struct { + GroundDestroyCount string + } + + type GroundRate struct { + GroundDestroyCount int + } + + type GameUser1 struct { + gorm.Model + StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"` + GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"` + } + + DB.Migrator().DropTable(&GameUser{}) + err := DB.AutoMigrate(&GameUser{}) + AssertEqual(t, nil, err) + + err = DB.Table("game_users").AutoMigrate(&GameUser1{}) + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") + AssertEqual(t, nil, err) +} diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 66b988c3..44cac6bf 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -7,8 +7,60 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" ) +func TestPostgresReturningIDWhichHasStringType(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Yasuo struct { + ID string `gorm:"default:gen_random_uuid()"` + Name string + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Yasuo{}) + + if err := DB.AutoMigrate(&Yasuo{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + yasuo := Yasuo{Name: "jinzhu"} + if err := DB.Create(&yasuo).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + if yasuo.ID == "" { + t.Fatal("should be able to has ID, but got zero value") + } + + var result Yasuo + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + yasuo.Name = "jinzhu1" + if err := DB.Save(&yasuo).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } +} + func TestPostgres(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() @@ -60,16 +112,55 @@ func TestPostgres(t *testing.T) { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } + + DB.Migrator().DropTable("log_usage") + + if err := DB.Exec(` +CREATE TABLE public.log_usage ( + log_id bigint NOT NULL +); + +ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( + SEQUENCE NAME public.log_usage_log_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + `).Error; err != nil { + t.Fatalf("failed to create table, got error %v", err) + } + + columns, err := DB.Migrator().ColumnTypes("log_usage") + if err != nil { + t.Fatalf("failed to get columns, got error %v", err) + } + + hasLogID := false + for _, column := range columns { + if column.Name() == "log_id" { + hasLogID = true + autoIncrement, ok := column.AutoIncrement() + if !ok || !autoIncrement { + t.Fatalf("column log_id should be auto incrementment") + } + } + } + + if !hasLogID { + t.Fatalf("failed to found column log_id") + } } type Post struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } @@ -98,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPostgresOnConstraint(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Thing struct { + gorm.Model + SomeID string + OtherID string + Data string + } + + DB.Migrator().DropTable(&Thing{}) + DB.Migrator().CreateTable(&Thing{}) + if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { + t.Error(err) + } + + thing := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something", + } + + DB.Create(&thing) + + thing2 := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something else", + } + + result := DB.Clauses(clause.OnConflict{ + OnConstraint: "some_id_other_id_unique", + UpdateAll: true, + }).Create(&thing2) + if result.Error != nil { + t.Errorf("creating second thing: %v", result.Error) + } + + var things []Thing + if err := DB.Find(&things).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if len(things) > 1 { + t.Errorf("expected 1 thing got more") + } +} + +type CompanyNew struct { + ID int + Name int +} + +func TestAlterColumnDataType(t *testing.T) { + DB.AutoMigrate(Company{}) + + if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { + t.Fatalf("failed to alter column from string to int, got error %v", err) + } + + DB.AutoMigrate(Company{}) +} diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 8730e547..64baa01b 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,8 @@ package tests_test import ( "context" + "errors" + "sync" "testing" "time" @@ -88,3 +90,107 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } tx2.Commit() } + +func TestPreparedStmtDeadlock(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + user := User{Name: "jinzhu"} + tx.Create(&user) + + var result User + tx.First(&result) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 2) + for _, stmt := range conn.Stmts { + if stmt == nil { + t.Fatalf("stmt cannot bee nil") + } + } + + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtError(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + // err prepare + tag := Tag{Locale: "zh"} + tx.Table("users").Find(&tag) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 0) + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtInTransaction(t *testing.T) { + user := User{Name: "jinzhu"} + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) + return errors.New("test") + }); err == nil { + t.Error(err) + } + + var result User + if err := DB.First(&result, user.ID).Error; err == nil { + t.Errorf("Failed, got error: %v", err) + } +} + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + pdb.Mux.Lock() + if len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +} diff --git a/tests/query_test.go b/tests/query_test.go index 4569fe1a..fa8f09e8 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -216,6 +216,30 @@ func TestFind(t *testing.T) { } } + // test array + var models2 [3]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models2[idx], user) + }) + } + } + + // test smaller array + var models3 [2]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) + } else { + for idx, user := range users[:2] { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models3[idx], user) + }) + } + } + var none []User if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) @@ -384,6 +408,13 @@ func TestFindInBatchesWithError(t *testing.T) { if totalBatch != 0 { t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) } + + if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error != gorm.ErrPrimaryKeyRequired { + t.Fatal("expected errors to have occurred, but nothing happened") + } } func TestFillSmallerStruct(t *testing.T) { diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 946536bf..a040a4db 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -18,6 +18,8 @@ type SerializerStruct struct { gorm.Model Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type @@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) { } var result SerializerStruct - if err := DB.First(&result, data.ID).Error; err != nil { + if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a9b920dc..0fbd6118 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) { t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") } - date, _ := time.Parse("2006-01-02", "2021-10-18") + date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local) // find sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) { if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Raw("SELECT * FROM users ?", clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}}, + }) + }) + assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql) } // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. diff --git a/tests/table_test.go b/tests/table_test.go index 0289b7b8..f538c691 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -5,6 +5,8 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests" ) @@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) { AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } + +type UserWithTableNamer struct { + gorm.Model + Name string +} + +func (UserWithTableNamer) TableName(namer schema.Namer) string { + return namer.TableName("user") +} + +func TestTableWithNamer(t *testing.T) { + var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: "t_", + }}) + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) + }) + + if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) { + t.Errorf("Table with namer, got %v", sql) + } +} diff --git a/tests/tracer_test.go b/tests/tracer_test.go new file mode 100644 index 00000000..3e9a4052 --- /dev/null +++ b/tests/tracer_test.go @@ -0,0 +1,34 @@ +package tests_test + +import ( + "context" + "time" + + "gorm.io/gorm/logger" +) + +type Tracer struct { + Logger logger.Interface + Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { + return S.Logger.LogMode(level) +} + +func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { + S.Logger.Info(ctx, s, i...) +} + +func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { + S.Logger.Warn(ctx, s, i...) +} + +func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { + S.Logger.Error(ctx, s, i...) +} + +func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + S.Logger.Trace(ctx, begin, fc, err) + S.Test(ctx, begin, fc, err) +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 0ac04a04..5872da94 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) { return errors.New("the error message") }) - if err.Error() != "the error message" { + if err != nil && err.Error() != "the error message" { t.Fatalf("Transaction return error will equal the block returns error") } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index f90c4518..e84dc14a 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) { } r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) - if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 2990c20f..c89b944a 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -2,6 +2,7 @@ package tests import ( "gorm.io/gorm" + "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -13,7 +14,14 @@ func (DummyDialector) Name() string { return "dummy" } -func (DummyDialector) Initialize(*gorm.DB) error { +func (DummyDialector) Initialize(db *gorm.DB) error { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + LastInsertIDReversed: true, + }) + return nil } diff --git a/utils/tests/models.go b/utils/tests/models.go index 22e8e659..ec1651a3 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -64,8 +64,8 @@ type Language struct { type Coupon struct { ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` - AmountOff uint32 `gorm:"amount_off"` - PercentOff float32 `gorm:"percent_off"` + AmountOff uint32 `gorm:"column:amount_off"` + PercentOff float32 `gorm:"column:percent_off"` } type CouponProduct struct { diff --git a/utils/utils.go b/utils/utils.go index 296917b9..e08533cd 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,8 +3,8 @@ package utils import ( "database/sql/driver" "fmt" + "path/filepath" "reflect" - "regexp" "runtime" "strconv" "strings" @@ -16,7 +16,18 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems - gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") + gormSourceDir = sourceDir(file) +} + +func sourceDir(file string) string { + dir := filepath.Dir(file) + dir = filepath.Dir(dir) + + s := filepath.Dir(dir) + if filepath.Base(s) != "gorm.io" { + s = dir + } + return filepath.ToSlash(s) + "/" } // FileWithLineNum return the file name and line number of the current file diff --git a/utils/utils_test.go b/utils/utils_test.go index 27dfee16..71eef964 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,8 +1,13 @@ package utils import ( + "database/sql" + "database/sql/driver" + "errors" + "math" "strings" "testing" + "time" ) func TestIsValidDBNameChar(t *testing.T) { @@ -13,6 +18,29 @@ func TestIsValidDBNameChar(t *testing.T) { } } +func TestCheckTruth(t *testing.T) { + checkTruthTests := []struct { + v string + out bool + }{ + {"123", true}, + {"true", true}, + {"", false}, + {"false", false}, + {"False", false}, + {"FALSE", false}, + {"\u0046alse", false}, + } + + for _, test := range checkTruthTests { + t.Run(test.v, func(t *testing.T) { + if out := CheckTruth(test.v); out != test.out { + t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out) + } + }) + } +} + func TestToStringKey(t *testing.T) { cases := []struct { values []interface{} @@ -29,3 +57,81 @@ func TestToStringKey(t *testing.T) { } } } + +func TestContains(t *testing.T) { + containsTests := []struct { + name string + elems []string + elem string + out bool + }{ + {"exists", []string{"1", "2", "3"}, "1", true}, + {"not exists", []string{"1", "2", "3"}, "4", false}, + } + for _, test := range containsTests { + t.Run(test.name, func(t *testing.T) { + if out := Contains(test.elems, test.elem); test.out != out { + t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out) + } + }) + } +} + +type ModifyAt sql.NullTime + +// Value return a Unix time. +func (n ModifyAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time.Unix(), nil +} + +func TestAssertEqual(t *testing.T) { + now := time.Now() + assertEqualTests := []struct { + name string + src, dst interface{} + out bool + }{ + {"error equal", errors.New("1"), errors.New("1"), true}, + {"error not equal", errors.New("1"), errors.New("2"), false}, + {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, + {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + } + for _, test := range assertEqualTests { + t.Run(test.name, func(t *testing.T) { + if out := AssertEqual(test.src, test.dst); test.out != out { + t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out) + } + }) + } +} + +func TestToString(t *testing.T) { + tests := []struct { + name string + in interface{} + out string + }{ + {"int", math.MaxInt64, "9223372036854775807"}, + {"int8", int8(math.MaxInt8), "127"}, + {"int16", int16(math.MaxInt16), "32767"}, + {"int32", int32(math.MaxInt32), "2147483647"}, + {"int64", int64(math.MaxInt64), "9223372036854775807"}, + {"uint", uint(math.MaxUint64), "18446744073709551615"}, + {"uint8", uint8(math.MaxUint8), "255"}, + {"uint16", uint16(math.MaxUint16), "65535"}, + {"uint32", uint32(math.MaxUint32), "4294967295"}, + {"uint64", uint64(math.MaxUint64), "18446744073709551615"}, + {"string", "abc", "abc"}, + {"other", true, ""}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if out := ToString(test.in); test.out != out { + t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out) + } + }) + } +} diff --git a/utils/utils_unix_test.go b/utils/utils_unix_test.go new file mode 100644 index 00000000..450cbe2a --- /dev/null +++ b/utils/utils_unix_test.go @@ -0,0 +1,38 @@ +//go:build unix +// +build unix + +package utils + +import ( + "testing" +) + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/Users/name/go/pkg/mod/gorm.io/", + }, + { + file: "/go/work/proj/gorm/utils/utils.go", + want: "/go/work/proj/gorm/", + }, + { + file: "/go/work/proj/gorm_alias/utils/utils.go", + want: "/go/work/proj/gorm_alias/", + }, + { + file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/", + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} diff --git a/utils/utils_windows_test.go b/utils/utils_windows_test.go new file mode 100644 index 00000000..8b1c519d --- /dev/null +++ b/utils/utils_windows_test.go @@ -0,0 +1,35 @@ +package utils + +import ( + "testing" +) + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/Users/name/go/pkg/mod/gorm.io/`, + }, + { + file: `C:/go/work/proj/gorm/utils/utils.go`, + want: `C:/go/work/proj/gorm/`, + }, + { + file: `C:/go/work/proj/gorm_alias/utils/utils.go`, + want: `C:/go/work/proj/gorm_alias/`, + }, + { + file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`, + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +}