Merge branch 'go-gorm:master' into master

This commit is contained in:
Paras Waykole 2021-09-10 18:19:16 +05:30 committed by GitHub
commit 37c225f4c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 559 additions and 200 deletions

15
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,15 @@
---
version: 2
updates:
- package-ecosystem: gomod
directory: /
schedule:
interval: weekly
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
- package-ecosystem: gomod
directory: /tests
schedule:
interval: weekly

View File

@ -10,7 +10,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
uses: actions/stale@v4
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -10,7 +10,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
uses: actions/stale@v4
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -8,4 +8,4 @@ jobs:
- name: Check out code into the Go module directory
uses: actions/checkout@v1
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v1
uses: reviewdog/action-golangci-lint@v2

View File

@ -10,7 +10,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
uses: actions/stale@v4
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days"

View File

@ -13,7 +13,7 @@ jobs:
sqlite:
strategy:
matrix:
go: ['1.16', '1.15']
go: ['1.17', '1.16', '1.15']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
@ -39,7 +39,7 @@ jobs:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
go: ['1.16', '1.15']
go: ['1.17', '1.16', '1.15']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -83,7 +83,7 @@ jobs:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.16', '1.15']
go: ['1.17', '1.16', '1.15']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
@ -125,7 +125,7 @@ jobs:
sqlserver:
strategy:
matrix:
go: ['1.16', '1.15']
go: ['1.17', '1.16', '1.15']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}

View File

@ -102,8 +102,8 @@ func (p *processor) Execute(db *DB) *DB {
// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)

View File

@ -310,33 +310,22 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
}
}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
if stmt.DB.FullSaveAssociations {
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
for _, dbName := range s.DBNames {
if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) {
continue
}
if !s.LookUpField(dbName).PrimaryKey {
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
}
}
}
if len(defaultUpdatingColumns) > 0 {
columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames {
columns = append(columns, clause.Column{Name: dbName})
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
}
return clause.OnConflict{
Columns: columns,
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
onConflict.UpdateAll = stmt.DB.FullSaveAssociations
if !onConflict.UpdateAll {
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
}
} else {
onConflict.DoNothing = true
}
return clause.OnConflict{DoNothing: true}
return
}
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {

View File

@ -37,7 +37,6 @@ func Create(config *Config) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error != nil {
// maybe record logger TODO
return
}
@ -64,11 +63,9 @@ func Create(config *Config) func(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if !(db.RowsAffected > 0) {
return
}
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
@ -107,7 +104,6 @@ func Create(config *Config) func(db *gorm.DB) {
db.AddError(err)
}
}
}
}
}
@ -348,15 +344,19 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) > 1 {
if stmt.Schema != nil && len(values.Columns) >= 1 {
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
columns = append(columns, column.Name)
}
}
}
}
onConflict.DoUpdates = clause.AssignmentColumns(columns)

View File

@ -44,7 +44,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
columns = make([]string, 0, len(mapValues))
)
// when the length of mapValues,return directly here
// when the length of mapValues is zero,return directly here
// no need to call stmt.SelectAndOmitColumns method
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)

View File

@ -1,6 +1,7 @@
package callbacks
import (
"fmt"
"reflect"
"gorm.io/gorm"
@ -104,6 +105,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 {
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
@ -113,6 +115,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
}
fieldValues := make([]interface{}, len(relForeignFields))
@ -142,7 +145,8 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
fieldValues[idx], _ = field.ValueOf(elem)
}
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok {
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
@ -160,5 +164,8 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
}
} else {
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()))
}
}
}

View File

@ -147,6 +147,21 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
if join.On != nil {
onStmt := gorm.Statement{Table: tableAliasName, DB: db}
join.On.Build(&onStmt)
onSQL := onStmt.SQL.String()
vars := onStmt.Vars
for idx, v := range onStmt.Vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
joins = append(joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
@ -209,7 +224,7 @@ func Preload(db *gorm.DB) {
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
} else {
db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}
}

View File

@ -51,7 +51,10 @@ func BeforeUpdate(db *gorm.DB) {
}
func Update(db *gorm.DB) {
if db.Error == nil {
if db.Error != nil {
return
}
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
@ -84,7 +87,6 @@ func Update(db *gorm.DB) {
}
}
}
}
func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
@ -220,12 +222,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
}
}
default:
var updatingSchema = stmt.Schema
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema
}
}
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
@ -252,6 +263,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
}
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}

View File

@ -50,15 +50,14 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
tx.Statement.Table = results[1]
return
}
} else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1]
return
}
} else {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name
}
return
}
@ -172,8 +171,19 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
// Joins specify Joins conditions
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(args) > 0 {
if db, ok := args[0].(*DB); ok {
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where})
}
return
}
}
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
return
}
@ -209,13 +219,15 @@ func (db *DB) Order(value interface{}) (tx *DB) {
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v},
})
default:
case string:
if v != "" {
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
Column: clause.Column{Name: v, Raw: true},
}},
})
}
}
return
}

View File

@ -173,7 +173,12 @@ func (expr NamedExpr) Build(builder Builder) {
}
if inName {
builder.AddVar(builder, namedMap[string(name)])
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
}
}
@ -205,11 +210,12 @@ func (in IN) Build(builder Builder) {
}
func (in IN) NegationBuild(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.WriteString(" IS NOT NULL")
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteQuoted(in.Column)
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values[0])
break
@ -217,7 +223,6 @@ func (in IN) NegationBuild(builder Builder) {
fallthrough
default:
builder.WriteQuoted(in.Column)
builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')

View File

@ -60,6 +60,11 @@ func TestNamedExpr(t *testing.T) {
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name AND name2 = @@name",
Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}},
Result: "name1 = ? AND name2 = @@name",
ExpectedVars: []interface{}{"jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
@ -73,13 +78,13 @@ func TestNamedExpr(t *testing.T) {
}, {
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist",
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{},
@ -151,6 +156,18 @@ func TestExpression(t *testing.T) {
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` NOT IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`id`) = ?",
}, {
Expressions: []clause.Expression{
clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`users`.`id`) >= ?",
}}
for idx, result := range results {

View File

@ -43,3 +43,17 @@ func (s Select) MergeClause(clause *Clause) {
clause.Expression = s
}
}
// CommaExpression represents a group of expressions separated by commas.
type CommaExpression struct {
Exprs []Expression
}
func (comma CommaExpression) Build(builder Builder) {
for idx, expr := range comma.Exprs {
if idx > 0 {
_, _ = builder.WriteString(", ")
}
expr.Build(builder)
}
}

View File

@ -31,6 +31,18 @@ func TestSelect(t *testing.T) {
}, clause.From{}},
"SELECT `name` FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}},
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}},
clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}},
},
},
}, clause.From{}},
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
},
}
for idx, result := range results {

View File

@ -376,7 +376,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
db.Statement.Clauses["SELECT"] = selectClause
tx.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
@ -390,7 +390,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
@ -410,9 +410,9 @@ func (db *DB) Count(count *int64) (tx *DB) {
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(db.Statement.Clauses, "ORDER BY")
delete(tx.Statement.Clauses, "ORDER BY")
defer func() {
db.Statement.Clauses["ORDER BY"] = orderByClause
tx.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}

View File

@ -66,6 +66,7 @@ var (
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}

View File

@ -2,6 +2,7 @@ package migrator
import (
"context"
"database/sql"
"fmt"
"reflect"
"regexp"
@ -166,11 +167,13 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration {
createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += ","
}
}
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?,"
@ -195,6 +198,10 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
createTableSQL += "INDEX ? ?"
if idx.Comment != "" {
createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
if idx.Option != "" {
createTableSQL += " " + idx.Option
}
@ -382,11 +389,11 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
alterColumn = true
}
}
@ -414,22 +421,31 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
return nil
}
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
columnTypes = make([]gorm.ColumnType, 0)
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err == nil {
if err != nil {
return err
}
defer rows.Close()
rawColumnTypes, err := rows.ColumnTypes()
if err == nil {
var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, c)
}
}
}
return err
return nil
})
return
return columnTypes, execErr
}
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
@ -489,9 +505,10 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
}
if field := stmt.Schema.LookUpField(name); field != nil {
for _, cc := range checkConstraints {
if cc.Field == field {
return nil, &cc, stmt.Table
for k := range checkConstraints {
if checkConstraints[k].Field == field {
v := checkConstraints[k]
return nil, &v, stmt.Table
}
}
@ -601,6 +618,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
createIndexSQL += " USING " + idx.Type
}
if idx.Comment != "" {
createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
if idx.Option != "" {
createIndexSQL += " " + idx.Option
}
@ -608,7 +629,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.DB.Exec(createIndexSQL, values...).Error
}
return fmt.Errorf("failed to create index with name %v", name)
return fmt.Errorf("failed to create index with name %s", name)
})
}

View File

@ -35,7 +35,7 @@ func (db *PreparedStmtDB) Close() {
for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
stmt.Close()
go stmt.Close()
}
}
@ -56,7 +56,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.Unlock()
return stmt, nil
} else if ok {
stmt.Close()
go stmt.Close()
}
stmt, err := conn.PrepareContext(ctx, query)
@ -83,7 +83,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
db.Mux.Lock()
stmt.Close()
go stmt.Close()
delete(db.Stmts, query)
db.Mux.Unlock()
}
@ -97,7 +97,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
db.Mux.Lock()
stmt.Close()
go stmt.Close()
delete(db.Stmts, query)
db.Mux.Unlock()
}
@ -138,7 +138,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
}
@ -152,7 +152,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
}

View File

@ -208,6 +208,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
}
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
values[idx] = dest
} else {
values[idx] = &sql.RawBytes{}
}
@ -238,6 +240,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
}
}
default:
db.AddError(rows.Scan(dest))
}
}

View File

@ -490,21 +490,22 @@ func (field *Field) setupValuerAndSetter() {
return
} else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(value)
fieldType := field.FieldType.Elem()
if reflectValType.AssignableTo(field.FieldType.Elem()) {
if reflectValType.AssignableTo(fieldType) {
if !fieldValue.IsValid() {
fieldValue = reflect.New(field.FieldType.Elem())
fieldValue = reflect.New(fieldType)
} else if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
fieldValue.Set(reflect.New(fieldType))
}
fieldValue.Elem().Set(reflectV)
return
} else if reflectValType.ConvertibleTo(field.FieldType.Elem()) {
} else if reflectValType.ConvertibleTo(fieldType) {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
fieldValue.Set(reflect.New(fieldType))
}
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
fieldValue.Elem().Set(reflectV.Convert(fieldType))
return
}
}
@ -520,7 +521,7 @@ func (field *Field) setupValuerAndSetter() {
err = setter(value, v)
}
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
return fmt.Errorf("failed to set value %+v to field %s", v, field.Name)
}
}

View File

@ -4,6 +4,7 @@ import (
"crypto/sha1"
"encoding/hex"
"fmt"
"regexp"
"strings"
"unicode/utf8"
@ -13,6 +14,7 @@ import (
// Namer namer interface
type Namer interface {
TableName(table string) string
SchemaName(table string) string
ColumnName(table, column string) string
JoinTableName(joinTable string) string
RelationshipFKName(Relationship) string
@ -41,6 +43,16 @@ func (ns NamingStrategy) TableName(str string) string {
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName
func (ns NamingStrategy) SchemaName(table string) string {
table = strings.TrimPrefix(table, ns.TablePrefix)
if ns.SingularTable {
return ns.toSchemaName(table)
}
return ns.toSchemaName(inflection.Singular(table))
}
// ColumnName convert string to column name
func (ns NamingStrategy) ColumnName(table, column string) string {
return ns.toDBName(column)
@ -154,3 +166,11 @@ func (ns NamingStrategy) toDBName(name string) string {
ret := buf.String()
return ret
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1)
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}
return result
}

View File

@ -33,6 +33,26 @@ func TestToDBName(t *testing.T) {
t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key))
}
}
maps = map[string]string{
"x": "X",
"user_restrictions": "UserRestriction",
"this_is_a_test": "ThisIsATest",
"abc_and_jkl": "AbcAndJkl",
"employee_id": "EmployeeID",
"field_x": "FieldX",
"http_and_smtp": "HTTPAndSMTP",
"http_server_handler_for_url_id": "HTTPServerHandlerForURLID",
"uuid": "UUID",
"http_url": "HTTPURL",
"sha256_hash": "Sha256Hash",
"this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID",
}
for key, value := range maps {
if ns.SchemaName(key) != value {
t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key))
}
}
}
func TestNamingStrategy(t *testing.T) {

View File

@ -238,7 +238,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
for idx, relField := range refForeignFields {
joinFieldName := relation.FieldSchema.Name + relField.Name
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}

View File

@ -119,20 +119,13 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
if v, loaded := cacheStore.Load(modelType); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
@ -228,11 +221,25 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
}
}
}
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
@ -244,19 +251,20 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
fieldValue := reflect.New(field.IndirectFieldType)
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
fieldInterface := fieldValue.Interface()
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
}

View File

@ -50,6 +50,7 @@ type Statement struct {
type join struct {
Name string
Conds []interface{}
On *clause.Where
}
// StatementModifier statement modifier interface
@ -129,6 +130,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case clause.Expr:
v.Build(stmt)
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
case []string:

View File

@ -22,8 +22,8 @@ func TestDelete(t *testing.T) {
}
}
if err := DB.Delete(&users[1]).Error; err != nil {
t.Errorf("errors happened when delete: %v", err)
if res := DB.Delete(&users[1]); res.Error != nil || res.RowsAffected != 1 {
t.Errorf("errors happened when delete: %v, affected: %v", res.Error, res.RowsAffected)
}
var result User

View File

@ -3,14 +3,14 @@ module gorm.io/gorm/tests
go 1.14
require (
github.com/google/uuid v1.2.0
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.2
github.com/lib/pq v1.6.0
gorm.io/driver/mysql v1.0.5
github.com/lib/pq v1.10.3
gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.1.0
gorm.io/driver/sqlite v1.1.4
gorm.io/driver/sqlserver v1.0.7
gorm.io/gorm v1.21.9
gorm.io/driver/sqlserver v1.0.9
gorm.io/gorm v1.21.14
)
replace gorm.io/gorm => ../

View File

@ -104,6 +104,27 @@ func TestJoinConds(t *testing.T) {
}
}
func TestJoinOn(t *testing.T) {
var user = *GetUser("joins-on", Config{Pets: 2})
DB.Save(&user)
var user1 User
onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"})
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1")
onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"})
var user2 User
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2")
}
func TestJoinsWithSelect(t *testing.T) {
type result struct {
ID uint

View File

@ -142,17 +142,36 @@ func TestSmartMigrateColumn(t *testing.T) {
}
func TestMigrateWithComment(t *testing.T) {
type UserWithComment struct {
func TestMigrateWithColumnComment(t *testing.T) {
type UserWithColumnComment struct {
gorm.Model
Name string `gorm:"size:111;index:,comment:这是一个index;comment:this is a 字段"`
Name string `gorm:"size:111;comment:this is a 字段"`
}
if err := DB.Migrator().DropTable(&UserWithComment{}); err != nil {
if err := DB.Migrator().DropTable(&UserWithColumnComment{}); err != nil {
t.Fatalf("Failed to drop table, got error %v", err)
}
if err := DB.AutoMigrate(&UserWithComment{}); err != nil {
if err := DB.AutoMigrate(&UserWithColumnComment{}); err != nil {
t.Fatalf("Failed to auto migrate, but got error %v", err)
}
}
func TestMigrateWithIndexComment(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
type UserWithIndexComment struct {
gorm.Model
Name string `gorm:"size:111;index:,comment:这是一个index"`
}
if err := DB.Migrator().DropTable(&UserWithIndexComment{}); err != nil {
t.Fatalf("Failed to drop table, got error %v", err)
}
if err := DB.AutoMigrate(&UserWithIndexComment{}); err != nil {
t.Fatalf("Failed to auto migrate, but got error %v", err)
}
}

View File

@ -436,6 +436,11 @@ func TestNot(t *testing.T) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Not(map[string]interface{}{"name": []string{}}).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IS NOT NULL").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
@ -842,7 +847,17 @@ func TestSearchWithEmptyChain(t *testing.T) {
func TestOrder(t *testing.T) {
dryDB := DB.Session(&gorm.Session{DryRun: true})
result := dryDB.Order("age desc, name").Find(&User{})
result := dryDB.Order("").Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Order(nil).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Order("age desc, name").Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc, name").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
}

View File

@ -63,6 +63,13 @@ func TestScan(t *testing.T) {
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
t.Errorf("Scan into struct map, got %#v", results)
}
type ID uint64
var id ID
DB.Raw("select id from users where id = ?", user2.ID).Scan(&id)
if uint(id) != user2.ID {
t.Errorf("Failed to scan to customized data type")
}
}
func TestScanRows(t *testing.T) {

View File

@ -30,6 +30,26 @@ func TestTable(t *testing.T) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
r = dryDB.Table("`people`").Table("`user`").Find(&User{}).Statement
if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
r = dryDB.Table("people as p").Table("user as u").Find(&User{}).Statement
if !regexp.MustCompile("SELECT \\* FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
r = dryDB.Table("people as p").Table("user").Find(&User{}).Statement
if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
r = dryDB.Table("gorm.people").Table("user").Find(&User{}).Statement
if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement
if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())

View File

@ -1,6 +1,7 @@
package tests_test
import (
"database/sql"
"testing"
"time"
@ -85,4 +86,48 @@ func TestUpdateHasOne(t *testing.T) {
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
CheckPet(t, pet4, pet)
})
t.Run("Restriction", func(t *testing.T) {
type CustomizeAccount struct {
gorm.Model
UserID sql.NullInt64
Number string `gorm:"<-:create"`
}
type CustomizeUser struct {
gorm.Model
Name string
Account CustomizeAccount `gorm:"foreignkey:UserID"`
}
DB.Migrator().DropTable(&CustomizeUser{})
DB.Migrator().DropTable(&CustomizeAccount{})
if err := DB.AutoMigrate(&CustomizeUser{}); err != nil {
t.Fatalf("failed to migrate, got error: %v", err)
}
if err := DB.AutoMigrate(&CustomizeAccount{}); err != nil {
t.Fatalf("failed to migrate, got error: %v", err)
}
number := "number-has-one-associations"
cusUser := CustomizeUser{
Name: "update-has-one-associations",
Account: CustomizeAccount{
Number: number,
},
}
if err := DB.Create(&cusUser).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
cusUser.Account.Number += "-update"
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
var account2 CustomizeAccount
DB.Find(&account2, "user_id = ?", cusUser.ID)
AssertEqual(t, account2.Number, number)
})
}

View File

@ -69,8 +69,10 @@ func TestUpdate(t *testing.T) {
}
values := map[string]interface{}{"Active": true, "age": 5}
if err := DB.Model(user).Updates(values).Error; err != nil {
t.Errorf("errors happened when update: %v", err)
if res := DB.Model(user).Updates(values); res.Error != nil {
t.Errorf("errors happened when update: %v", res.Error)
} else if res.RowsAffected != 1 {
t.Errorf("rows affected should be 1, but got : %v", res.RowsAffected)
} else if user.Age != 5 {
t.Errorf("Age should equals to 5, but got %v", user.Age)
} else if user.Active != true {
@ -131,7 +133,10 @@ func TestUpdates(t *testing.T) {
lastUpdatedAt := users[0].UpdatedAt
// update with map
DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100})
if res := DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}); res.Error != nil || res.RowsAffected != 1 {
t.Errorf("Failed to update users")
}
if users[0].Name != "updates_01_newname" || users[0].Age != 100 {
t.Errorf("Record should be updated also with map")
}
@ -642,6 +647,40 @@ func TestSave(t *testing.T) {
if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) {
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
}
user3 := *GetUser("save3", Config{})
DB.Create(&user3)
if err := DB.First(&User{}, "name = ?", "save3").Error; err != nil {
t.Fatalf("failed to find created user")
}
user3.Name = "save3_"
if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil {
t.Fatalf("failed to save user, got %v", err)
}
var result2 User
if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID {
t.Fatalf("failed to find updated user, got %v", err)
}
if err := DB.Model(User{Model: user3.Model}).Save(&struct {
gorm.Model
Placeholder string
Name string
}{
Model: user3.Model,
Placeholder: "placeholder",
Name: "save3__",
}).Error; err != nil {
t.Fatalf("failed to update user, got %v", err)
}
var result3 User
if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID {
t.Fatalf("failed to find updated user")
}
}
func TestSaveWithPrimaryValue(t *testing.T) {

View File

@ -1,9 +1,11 @@
package tests_test
import (
"regexp"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
@ -51,6 +53,19 @@ func TestUpsert(t *testing.T) {
if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name {
t.Fatalf("failed to upsert, got name %v", result.Name)
}
if name := DB.Dialector.Name(); name != "sqlserver" {
type RestrictedLanguage struct {
Code string `gorm:"primarykey"`
Name string
Lang string `gorm:"<-:create"`
}
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
}
}
func TestUpsertSlice(t *testing.T) {

View File

@ -11,6 +11,7 @@ import (
// He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table)
// He speaks many languages (many to many) and has many friends (many to many - single-table)
// His pet also has one Toy (has one - polymorphic)
// NamedPet is a reference to a Named `Pets` (has many)
type User struct {
gorm.Model
Name string
@ -18,6 +19,7 @@ type User struct {
Birthday *time.Time
Account Account
Pets []*Pet
NamedPet *Pet
Toys []Toy `gorm:"polymorphic:Owner"`
CompanyID *int
Company Company