Merge branch 'master' into patch-oracle

This commit is contained in:
Steve Fan 2020-10-03 23:15:42 +08:00 committed by GitHub
commit b2fdc25f73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 1609 additions and 242 deletions

5
.github/labels.json vendored
View File

@ -10,6 +10,11 @@
"colour": "#EDEDED",
"description": "general questions"
},
"invalid_question": {
"name": "type:invalid question",
"colour": "#CF2E1F",
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
},
"with_playground": {
"name": "type:with reproduction steps",
"colour": "#00ff00",

22
.github/workflows/invalid_question.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: "Close invalid questions issues"
on:
schedule:
- cron: "*/10 * * * *"
jobs:
stale:
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
remove-stale-when-updated: true
only-labels: "type:invalid question"

View File

@ -13,7 +13,7 @@ jobs:
uses: actions/stale@v3.0.7
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs."
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2

View File

@ -8,7 +8,6 @@ import (
"sort"
"time"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
@ -75,17 +74,22 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) {
curTime := time.Now()
stmt := db.Statement
db.RowsAffected = 0
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
}
}
}
if stmt.Dest != nil {
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
@ -154,7 +158,7 @@ func (p *processor) compile() (err error) {
p.callbacks = callbacks
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err)
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
}
return
}
@ -177,7 +181,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
}
func (c *callback) Remove(name string) error {
logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
@ -185,7 +189,7 @@ func (c *callback) Remove(name string) error {
}
func (c *callback) Replace(name string, fn func(*DB)) error {
logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.handler = fn
c.replace = true
@ -215,7 +219,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
for _, c := range cs {
// show warning message the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
}
names = append(names, c.name)
}

View File

@ -5,6 +5,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
func SaveBeforeAssociations(db *gorm.DB) {
@ -65,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
}
if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
DoNothing: true,
}).Create(elems.Interface()).Error) == nil {
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
@ -80,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
rv = rv.Addr()
}
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
DoNothing: true,
}).Create(rv.Interface()).Error) == nil {
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
@ -144,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(elems.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
@ -167,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(f.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(f.Interface()).Error)
}
}
}
@ -229,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(elems.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
}
}
@ -297,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) {
}
if elems.Len() > 0 {
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
@ -310,3 +304,32 @@ func SaveAfterAssociations(db *gorm.DB) {
}
}
}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
if stmt.DB.FullSaveAssociations {
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
for _, dbName := range s.DBNames {
if !s.LookUpField(dbName).PrimaryKey {
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
}
}
}
if len(defaultUpdatingColumns) > 0 {
var columns []clause.Column
if s.PrioritizedPrimaryField != nil {
columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
} else {
for _, dbName := range s.PrimaryFieldDBNames {
columns = append(columns, clause.Column{Name: dbName})
}
}
return clause.OnConflict{
Columns: columns,
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
}
}
return clause.OnConflict{DoNothing: true}
}

View File

@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)

View File

@ -1,6 +1,7 @@
package callbacks
import (
"fmt"
"reflect"
"gorm.io/gorm"
@ -87,9 +88,12 @@ func Create(config *Config) func(db *gorm.DB) {
}
case reflect.Struct:
if insertID > 0 {
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
}
}
} else {
db.AddError(err)
}
@ -252,8 +256,18 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
stmt.SQL.Grow(stmt.ReflectValue.Len() * 15)
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
if stmt.ReflectValue.Len() == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
for i := 0; i < stmt.ReflectValue.Len(); i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]

View File

@ -2,6 +2,7 @@ package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -21,6 +22,85 @@ func BeforeDelete(db *gorm.DB) {
}
}
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if restricted {
for column, v := range selectColumns {
if v {
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{}).Model(modelValue)
withoutConditions := false
if len(db.Statement.Selects) > 0 {
var selects []string
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if strings.HasPrefix(s, column+".") {
selects = append(selects, strings.TrimPrefix(s, column+"."))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions {
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
case schema.Many2Many:
var (
queryConds []clause.Expression
foreignFields []*schema.Field
relForeignKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
}
}
}
func Delete(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {

View File

@ -46,6 +46,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
)
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
for idx, mapValue := range mapValues {
for k, v := range mapValue {
if stmt.Schema != nil {

View File

@ -107,6 +107,26 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
fieldValues := make([]interface{}, len(relForeignFields))
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
}
}
}
for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {

View File

@ -11,11 +11,13 @@ func RowQuery(db *gorm.DB) {
}
if !db.DryRun {
if _, ok := db.Get("rows"); ok {
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
db.RowsAffected = -1
}
}
}

View File

@ -198,7 +198,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
// Order specify order when retrieve records from database
// db.Order("name DESC")
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance()

View File

@ -3,6 +3,7 @@ package clause
import (
"database/sql"
"database/sql/driver"
"go/ast"
"reflect"
)
@ -30,19 +31,23 @@ func (expr Expr) Build(builder Builder) {
)
for _, v := range []byte(expr.SQL) {
if v == '?' {
if v == '?' && len(expr.Vars) > idx {
if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
@ -85,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) {
for k, v := range value {
namedMap[k] = v
}
default:
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
}
}
}
}
}
@ -94,7 +110,7 @@ func (expr NamedExpr) Build(builder Builder) {
if v == '@' && !inName {
inName = true
name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' {
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
@ -106,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) {
}
builder.WriteByte(v)
} else if v == '?' {
} else if v == '?' && len(expr.Vars) > idx {
builder.AddVar(builder, expr.Vars[idx])
idx++
} else if inName {

View File

@ -37,6 +37,11 @@ func TestExpr(t *testing.T) {
}
func TestNamedExpr(t *testing.T) {
type NamedArgument struct {
Name1 string
Name2 string
}
results := []struct {
SQL string
Result string
@ -66,6 +71,15 @@ func TestNamedExpr(t *testing.T) {
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
}, {
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
}, {
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{},
Result: "create table ? (? ?, ? ?)",
}}
for idx, result := range results {

View File

@ -33,10 +33,8 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit > 0 {
if limit.Limit == 0 && v.Limit != 0 {
limit.Limit = v.Limit
} else if limit.Limit < 0 {
limit.Limit = 0
}
if limit.Offset == 0 && v.Offset > 0 {

View File

@ -1,5 +1,9 @@
package clause
import (
"strings"
)
// Where where clause
type Where struct {
Exprs []Expression
@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) {
}
}
wrapInParentheses := false
for idx, expr := range where.Exprs {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) {
}
}
if len(where.Exprs) > 1 {
switch v := expr.(type) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case Expr:
sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
if wrapInParentheses {
builder.WriteString(`(`)
expr.Build(builder)
builder.WriteString(`)`)
wrapInParentheses = false
} else {
expr.Build(builder)
}
}
}
@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) {
func And(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
} else if len(exprs) == 1 {
return exprs[0]
}
return AndConditions{Exprs: exprs}
}
@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(" AND ")
@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToLower(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}

View File

@ -27,4 +27,8 @@ var (
ErrRegistered = errors.New("registered")
// ErrInvalidField invalid field
ErrInvalidField = errors.New("invalid field")
// ErrEmptySlice empty slice found
ErrEmptySlice = errors.New("empty slice found")
// ErrDryRunModeUnsupported dry run mode unsupported
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
)

View File

@ -8,6 +8,7 @@ import (
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
@ -51,7 +52,8 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.callbacks.Update().Execute(tx)
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) {
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value)
}
}
@ -331,22 +333,43 @@ func (db *DB) Count(count *int64) (tx *DB) {
}
func (db *DB) Row() *sql.Row {
tx := db.getInstance()
tx := db.getInstance().InstanceSet("rows", false)
tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Row)
row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
}
return row
}
func (db *DB) Rows() (*sql.Rows, error) {
tx := db.Set("rows", true)
tx := db.getInstance().InstanceSet("rows", true)
tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Rows), tx.Error
rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
}
return rows, tx.Error
}
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
currentLogger, newLogger := db.Logger, logger.Recorder.New()
tx = db.getInstance()
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
tx.Logger = newLogger
if rows, err := tx.Rows(); err != nil {
tx.AddError(err)
} else {
defer rows.Close()
if rows.Next() {
tx.ScanRows(rows, dest)
}
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
tx.Logger = currentLogger
return
}
@ -377,9 +400,14 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance()
tx.Error = tx.Statement.Parse(dest)
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest))
tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
}
Scan(rows, tx, true)
return tx.Error
}

13
gorm.go
View File

@ -20,6 +20,8 @@ type Config struct {
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
@ -64,6 +66,7 @@ type Session struct {
WithConditions bool
SkipDefaultTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.AllowGlobalUpdate = true
}
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
if config.Context != nil {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
@ -169,12 +176,15 @@ func (db *DB) Session(config *Session) *DB {
if config.PrepareStmt {
if v, ok := db.cacheStore.Load("preparedStmt"); ok {
tx.Statement = tx.Statement.clone()
preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
}
@ -316,6 +326,9 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)

View File

@ -2,6 +2,7 @@ package logger
import (
"context"
"io/ioutil"
"log"
"os"
"time"
@ -19,6 +20,7 @@ const (
Magenta = "\033[35m"
Cyan = "\033[36m"
White = "\033[37m"
BlueBold = "\033[34;1m"
MagentaBold = "\033[35;1m"
RedBold = "\033[31;1m"
YellowBold = "\033[33;1m"
@ -54,29 +56,33 @@ type Interface interface {
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
}
var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
var (
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 100 * time.Millisecond,
LogLevel: Warn,
Colorful: true,
})
})
Recorder = traceRecorder{Interface: Default}
)
func New(writer Writer, config Config) Interface {
var (
infoStr = "%s\n[info] "
warnStr = "%s\n[warn] "
errStr = "%s\n[error] "
traceStr = "%s\n[%v] [rows:%d] %s"
traceWarnStr = "%s\n[%v] [rows:%d] %s"
traceErrStr = "%s %s\n[%v] [rows:%d] %s"
traceStr = "%s\n[%.3fms] [rows:%v] %s"
traceWarnStr = "%s\n[%.3fms] [rows:%v] %s"
traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
)
if config.Colorful {
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
}
return &logger{
@ -133,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
switch {
case err != nil && l.LogLevel >= Error:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel >= Info:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
}
}
}
type traceRecorder struct {
Interface
BeginAt time.Time
SQL string
RowsAffected int64
Err error
}
func (l traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface}
}
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin
l.SQL, l.RowsAffected = fc()
l.Err = err
}

View File

@ -9,6 +9,8 @@ import (
"strings"
"time"
"unicode"
"gorm.io/gorm/utils"
)
func isPrintable(s []byte) bool {
@ -24,19 +26,38 @@ var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int)
var vars = make([]interface{}, len(avars))
copy(vars, avars)
var vars = make([]string, len(avars))
convertParams = func(v interface{}, idx int) {
switch v := v.(type) {
case bool:
vars[idx] = fmt.Sprint(v)
vars[idx] = strconv.FormatBool(v)
case time.Time:
if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
}
case *time.Time:
if v != nil {
if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
}
} else {
vars[idx] = "NULL"
}
case fmt.Stringer:
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
case driver.Valuer:
reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
r, _ := v.Value()
convertParams(r, idx)
} else {
vars[idx] = "NULL"
}
case []byte:
if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
@ -44,7 +65,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
vars[idx] = escaper + "<binary>" + escaper
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = fmt.Sprintf("%d", v)
vars[idx] = utils.ToString(v)
case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v)
case string:
@ -70,18 +91,30 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
}
}
for idx, v := range vars {
for idx, v := range avars {
convertParams(v, idx)
}
if numericPlaceholder == nil {
for _, v := range vars {
sql = strings.Replace(sql, "?", v.(string), 1)
var idx int
var newSQL strings.Builder
for _, v := range []byte(sql) {
if v == '?' {
if len(vars) > idx {
newSQL.WriteString(vars[idx])
idx++
continue
}
}
newSQL.WriteByte(v)
}
sql = newSQL.String()
} else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
}
}

View File

@ -1,13 +1,39 @@
package logger_test
import (
"database/sql/driver"
"encoding/json"
"fmt"
"regexp"
"strings"
"testing"
"github.com/jinzhu/now"
"gorm.io/gorm/logger"
)
type JSON json.RawMessage
func (j JSON) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
return json.RawMessage(j).MarshalJSON()
}
type ExampleStruct struct {
Name string
Val string
}
func (s ExampleStruct) Value() (driver.Value, error) {
return json.Marshal(s)
}
func format(v []byte, escaper string) string {
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
}
func TestExplainSQL(t *testing.T) {
type role string
type password []byte
@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) {
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password([]byte("pass"))
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
)
results := []struct {
@ -29,6 +59,12 @@ func TestExplainSQL(t *testing.T) {
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
@ -47,6 +83,18 @@ func TestExplainSQL(t *testing.T) {
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
}
for idx, r := range results {

View File

@ -9,7 +9,7 @@ import (
// Migrator returns migrator
func (db *DB) Migrator() Migrator {
return db.Dialector.Migrator(db)
return db.Dialector.Migrator(db.Session(&Session{WithConditions: true}))
}
// AutoMigrate run auto migration for given models

View File

@ -71,7 +71,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
} else {
} else if field.DefaultValue != "(-)" {
expr.SQL += " DEFAULT " + field.DefaultValue
}
}
@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
}
}
for _, idx := range stmt.Schema.ParseIndexes() {
if !tx.Migrator().HasIndex(value, idx.Name) {
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
return nil
}); err != nil {
return err
@ -297,10 +306,12 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType,
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
@ -354,9 +365,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true
} else {
// has size in data type and not equal
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1)
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
alterColumn = true
}
}
@ -386,8 +397,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err == nil {
defer rows.Close()
columnTypes, err = rows.ColumnTypes()
}
return err
@ -584,6 +596,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
var (
modelNames, orderedModelNames []string
orderedModelNamesMap = map[string]bool{}
parsedSchemas = map[*schema.Schema]bool{}
valuesMap = map[string]Dependency{}
insertIntoOrderedList func(name string)
parseDependence func(value interface{}, addToList bool)
@ -593,23 +606,35 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
dep := Dependency{
Statement: &gorm.Statement{DB: m.DB, Dest: value},
}
beDependedOn := map[*schema.Schema]bool{}
if err := dep.Parse(value); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
}
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
return
}
parsedSchemas[dep.Statement.Schema] = true
for _, rel := range dep.Schema.Relationships.Relations {
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema)
}
if rel.JoinTable != nil {
if rel.Schema != rel.FieldSchema {
dep.Depends = append(dep.Depends, rel.FieldSchema)
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
beDependedOn[rel.FieldSchema] = true
}
if rel.JoinTable != nil {
// append join value
defer func(joinValue interface{}) {
defer func(rel *schema.Relationship, joinValue interface{}) {
if !beDependedOn[rel.FieldSchema] {
dep.Depends = append(dep.Depends, rel.FieldSchema)
} else {
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
parseDependence(fieldValue, autoAdd)
}
parseDependence(joinValue, autoAdd)
}(reflect.New(rel.JoinTable.ModelType).Interface())
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
}
}

View File

@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() {
db.Mux.Unlock()
}
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) {
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok {
db.Mux.RUnlock()
@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
return stmt, nil
}
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
stmt, err := db.ConnPool.PrepareContext(ctx, query)
if err == nil {
db.Stmts[query] = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(query)
stmt, err := db.prepare(ctx, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(query)
stmt, err := db.prepare(ctx, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(query)
stmt, err := db.prepare(ctx, query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
@ -99,10 +99,24 @@ type PreparedStmtTX struct {
PreparedStmtDB *PreparedStmtDB
}
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil {
return tx.Tx.Commit()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil {
return tx.Tx.Rollback()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
if err == nil {
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
@ -114,7 +128,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
if err == nil {
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
if err != nil {
@ -128,9 +142,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(query)
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
if err == nil {
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
}
return &sql.Row{}
}

86
scan.go
View File

@ -2,22 +2,63 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"reflect"
"strings"
"time"
"gorm.io/gorm/schema"
)
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
}
} else {
for idx := range columns {
values[idx] = new(interface{})
}
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
mapValue[column] = string(b)
}
} else {
mapValue[column] = nil
}
}
}
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
db.RowsAffected++
db.AddError(rows.Scan(values...))
@ -28,41 +69,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
mapValue = *v
}
}
for idx, column := range columns {
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
prepareValues(values, db, columnTypes, columns)
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
for idx, column := range columns {
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
}
case *int, *int64, *uint, *uint64, *float32, *float64:
case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time:
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
@ -114,7 +136,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
// pluck values into slice of data
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok {
isPluck = true
} else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) {
isPluck = true
}
}
for initialized || rows.Next() {
initialized = false
db.RowsAffected++

View File

@ -18,6 +18,8 @@ type DataType string
type TimeType int64
var TimeReflectType = reflect.TypeOf(time.Time{})
const (
UnixSecond TimeType = 1
UnixMillisecond TimeType = 2
@ -70,6 +72,8 @@ type Field struct {
}
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error
field := &Field{
Name: fieldStruct.Name,
BindNames: []string{fieldStruct.Name},
@ -100,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v)
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
for i := 0; i < rv.Type().NumField(); i++ {
newFieldType := rv.Type().Field(i).Type
for newFieldType.Kind() == reflect.Ptr {
@ -151,7 +155,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
if num, ok := field.TagSettings["SIZE"]; ok {
var err error
if field.Size, err = strconv.Atoi(num); err != nil {
field.Size = -1
}
@ -177,33 +180,42 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Comment = val
}
// default value is function or null or blank (primary keys)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Bool:
field.DataType = Bool
if field.HasDefaultValue && field.DefaultValue != "" {
field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err)
}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.DataType = Int
if field.HasDefaultValue && field.DefaultValue != "" {
field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err)
}
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.DataType = Uint
if field.HasDefaultValue && field.DefaultValue != "" {
field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err)
}
}
case reflect.Float32, reflect.Float64:
field.DataType = Float
if field.HasDefaultValue && field.DefaultValue != "" {
field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err)
}
}
case reflect.String:
field.DataType = String
isFunc := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")")
if field.HasDefaultValue && !isFunc {
if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, "\"")
field.DefaultValueInterface = field.DefaultValue
@ -211,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
field.DataType = Time
@ -316,7 +328,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil {
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
schema.err = err
}
@ -331,10 +343,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
}
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
ef.DBName = prefix + ef.DBName
}
if ef.PrimaryKey {
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
ef.PrimaryKey = true
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
@ -350,6 +363,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.HasDefaultValue = false
}
}
}
for k, v := range field.TagSettings {
ef.TagSettings[k] = v
@ -612,6 +626,14 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).SetUint(uint64(data))
case []byte:
return field.Set(value, string(data))
case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
}
case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetUint(i)
@ -671,7 +693,7 @@ func (field *Field) setupValuerAndSetter() {
case []byte:
field.ReflectValueOf(value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(value).SetString(fmt.Sprint(data))
field.ReflectValueOf(value).SetString(utils.ToString(data))
case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:

View File

@ -14,7 +14,7 @@ import (
type Namer interface {
TableName(table string) string
ColumnName(table, column string) string
JoinTableName(table string) string
JoinTableName(joinTable string) string
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
@ -41,6 +41,10 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
// JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string {
if strings.ToLower(str) == str {
return ns.TablePrefix + str
}
if ns.SingularTable {
return ns.TablePrefix + toDBName(str)
}
@ -49,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string {
// RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name))
return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1)
}
// CheckerName generate checker name
func (ns NamingStrategy) CheckerName(table, column string) string {
return fmt.Sprintf("chk_%s_%s", table, column)
return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1)
}
// IndexName generate index name
func (ns NamingStrategy) IndexName(table, column string) string {
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
idxName = strings.Replace(idxName, ".", "_", -1)
if utf8.RuneCountInString(idxName) > 64 {
h := sha1.New()

View File

@ -32,3 +32,35 @@ func TestToDBName(t *testing.T) {
}
}
}
func TestNamingStrategy(t *testing.T) {
var ns = NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_name" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.user_languages" {
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.user_language" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.company" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}

View File

@ -165,6 +165,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
// use same data type for foreign keys
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
if relation.Polymorphic.PolymorphicID.Size == 0 {
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryKeyField,
@ -254,12 +257,18 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
})
}
joinTableFields = append(joinTableFields, reflect.StructField{
Name: schema.Name + field.Name,
Type: schema.ModelType,
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields))
relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
relName := relation.Schema.Name
relRefName := relation.FieldSchema.Name
@ -290,11 +299,15 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
// build references
for idx, f := range relation.JoinTable.Fields {
for _, f := range relation.JoinTable.Fields {
if f.Creatable || f.Readable || f.Updatable {
// use same data type for foreign keys
f.DataType = fieldsMap[f.Name].DataType
f.GORMDataType = fieldsMap[f.Name].GORMDataType
relation.JoinTable.PrimaryFields[idx] = f
if f.Size == 0 {
f.Size = fieldsMap[f.Name].Size
}
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
if ownPriamryField {
@ -321,6 +334,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
OwnPrimaryKey: ownPriamryField,
})
}
}
}
type guessLevel int
@ -428,6 +442,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
// use same data type for foreign keys
foreignField.DataType = primaryFields[idx].DataType
foreignField.GORMDataType = primaryFields[idx].GORMDataType
if foreignField.Size == 0 {
foreignField.Size = primaryFields[idx].Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx],

View File

@ -206,16 +206,39 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{
{"ID", "User", "UserReferID", "user_profiles", "", true},
{"ID", "Profile", "ProfileRefer", "user_profiles", "", false},
{"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
},
})
}
func TestBuildReadonlyMany2ManyRelation(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{
{"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
},
})
}
@ -267,3 +290,34 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) {
},
)
}
func TestMultipleMany2Many(t *testing.T) {
type Thing struct {
ID int
}
type Person struct {
ID int
Likes []Thing `gorm:"many2many:likes"`
Dislikes []Thing `gorm:"many2many:dislikes"`
}
checkStructRelation(t, &Person{},
Relation{
Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "likes", Table: "likes"},
References: []Reference{
{"ID", "Person", "PersonID", "likes", "", true},
{"ID", "Thing", "ThingID", "likes", "", false},
},
},
Relation{
Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"},
References: []Reference{
{"ID", "Person", "PersonID", "dislikes", "", true},
{"ID", "Thing", "ThingID", "dislikes", "", false},
},
},
)
}

View File

@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
schema := &Schema{
Name: modelType.Name(),
@ -133,7 +136,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
@ -219,7 +222,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
}

View File

@ -1,6 +1,7 @@
package schema_test
import (
"strings"
"sync"
"testing"
@ -194,6 +195,7 @@ func TestEmbeddedStruct(t *testing.T) {
ID int
OwnerID int
Name string
Ignored string `gorm:"-"`
}
type Corp struct {
@ -211,15 +213,87 @@ func TestEmbeddedStruct(t *testing.T) {
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
}
for _, f := range fields {
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
if f.Name != "Ignored" {
f.Creatable = true
f.Updatable = true
f.Readable = true
}
})
}
}
type CustomizedNamingStrategy struct {
schema.NamingStrategy
}
func (ns CustomizedNamingStrategy) ColumnName(table, column string) string {
baseColumnName := ns.NamingStrategy.ColumnName(table, column)
if table == "" {
return baseColumnName
}
s := strings.Split(table, "_")
var prefix string
switch len(s) {
case 1:
prefix = s[0][:3]
case 2:
prefix = s[0][:1] + s[1][:2]
default:
prefix = s[0][:1] + s[1][:1] + s[2][:1]
}
return prefix + "_" + baseColumnName
}
func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
type CorpBase struct {
gorm.Model
OwnerID string
}
type Company struct {
ID int
OwnerID int
Name string
Ignored string `gorm:"-"`
}
type Corp struct {
CorpBase
Base Company `gorm:"embedded;embeddedPrefix:company_"`
}
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}})
if err != nil {
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
}
fields := []schema.Field{
{Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
{Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
}
for _, f := range fields {
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
if f.Name != "Ignored" {
f.Creatable = true
f.Updatable = true
f.Readable = true
}
})
}
}

View File

@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa
return columns, queryValues
}
}
type embeddedNamer struct {
Table string
Namer
}

View File

@ -25,14 +25,7 @@ func (n DeletedAt) Value() (driver.Value, error) {
}
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{
clause.Where{Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
Value: nil,
},
}},
}
return []clause.Interface{SoftDeleteQueryClause{Field: f}}
}
type SoftDeleteQueryClause struct {

View File

@ -12,6 +12,7 @@ import (
"sync"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
@ -189,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
writer.WriteString("(NULL)")
}
case *DB:
subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance()
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance()
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
subdb.callbacks.Query().Execute(subdb)
writer.WriteString(subdb.Statement.SQL.String())
@ -298,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else {
values := make([]interface{}, reflectValue.Len())
for i := 0; i < reflectValue.Len(); i++ {
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values})
}
default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
}
@ -317,9 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
if field.Readable {
if v, isZero := field.ValueOf(reflectValue); !isZero {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
@ -330,9 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
if field.Readable {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}

View File

@ -1,6 +1,7 @@
package tests_test
import (
"errors"
"testing"
"time"
@ -287,6 +288,30 @@ func TestCreateEmptyStruct(t *testing.T) {
}
}
func TestCreateEmptySlice(t *testing.T) {
var data = []User{}
if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice {
t.Errorf("no data should be created, got %v", err)
}
var sliceMap = []map[string]interface{}{}
if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice {
t.Errorf("no data should be created, got %v", err)
}
}
func TestCreateInvalidSlice(t *testing.T) {
users := []*User{
GetUser("invalid_slice_1", Config{}),
GetUser("invalid_slice_2", Config{}),
nil,
}
if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) {
t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
}
}
func TestCreateWithExistingTimestamp(t *testing.T) {
user := User{Name: "CreateUserExistingTimestamp"}
curTime := now.MustParse("2016-01-01")

View File

@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) {
FieldAllowSave3 string `gorm:"->:false;<-:create"`
FieldReadonly string `gorm:"->"`
FieldIgnore string `gorm:"-"`
AutoUnixCreateTime int64 `gorm:"autocreatetime"`
AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"`
AutoUnixCreateTime int32 `gorm:"autocreatetime"`
AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"`
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
AutoUnixUpdateTime int64 `gorm:"autoupdatetime"`
AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"`
AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"`
AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"`
AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"`
AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"`
}
DB.Migrator().DropTable(&CustomizeFieldStruct{})
@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) {
t.Fatalf("invalid result: %#v", result)
}
if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 {
if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 {
t.Fatalf("invalid create/update unix time: %#v", result)
}
if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 {
if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 {
t.Fatalf("invalid create/update unix milli time: %#v", result)
}
if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 {
if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 {
t.Fatalf("invalid create/update unix nano time: %#v", result)
}
@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) {
var createWithDefaultTimeResult CustomizeFieldStruct
DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name)
if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
}
if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult)
}
if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
}
}

View File

@ -5,6 +5,7 @@ import (
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) {
t.Errorf("should returns no error while enable global update, but got err %v", err)
}
}
func TestDeleteWithAssociations(t *testing.T) {
user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1})
if err := DB.Create(user).Error; err != nil {
t.Fatalf("failed to create user, got error %v", err)
}
if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil {
t.Fatalf("failed to delete user, got error %v", err)
}
for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} {
if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} {
if count := DB.Model(&user).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
}
func TestDeleteSliceWithAssociations(t *testing.T) {
users := []User{
*GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}),
*GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}),
*GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}),
*GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}),
}
if err := DB.Create(users).Error; err != nil {
t.Fatalf("failed to create user, got error %v", err)
}
if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil {
t.Fatalf("failed to delete user, got error %v", err)
}
for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} {
if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} {
if count := DB.Model(&users).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
}

View File

@ -163,6 +163,8 @@ func TestEmbeddedRelations(t *testing.T) {
DB.Migrator().DropTable(&AdvancedUser{})
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
if DB.Dialector.Name() != "sqlite" {
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
}
}
}

View File

@ -7,11 +7,13 @@ require (
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
github.com/stevefan1999-personal/gorm-driver-oracle latest
gorm.io/driver/mysql v1.0.0
gorm.io/driver/postgres v1.0.0
gorm.io/driver/sqlite v1.1.0
gorm.io/driver/sqlserver v1.0.1
gorm.io/gorm v1.9.19
gorm.io/driver/mysql v1.0.1
gorm.io/driver/postgres v1.0.2
gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.4
gorm.io/gorm v1.20.2
)
replace gorm.io/gorm => ../
replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2

View File

@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Pets", func(t *testing.T) {
if len(user.Pets) != len(expect.Pets) {
t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets))
t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets))
}
sort.Slice(user.Pets, func(i, j int) bool {
@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Toys", func(t *testing.T) {
if len(user.Toys) != len(expect.Toys) {
t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys))
t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys))
}
sort.Slice(user.Toys, func(i, j int) bool {
@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Team", func(t *testing.T) {
if len(user.Team) != len(expect.Team) {
t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team))
t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team))
}
sort.Slice(user.Team, func(i, j int) bool {
@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Languages", func(t *testing.T) {
if len(user.Languages) != len(expect.Languages) {
t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages))
t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages))
}
sort.Slice(user.Languages, func(i, j int) bool {
@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Friends", func(t *testing.T) {
if len(user.Friends) != len(expect.Friends) {
t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends))
t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends))
}
sort.Slice(user.Friends, func(i, j int) bool {

View File

@ -5,12 +5,14 @@ import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type Person struct {
ID int
Name string
Addresses []Address `gorm:"many2many:person_addresses;"`
DeletedAt gorm.DeletedAt
}
type Address struct {
@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) {
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
t.Fatalf("address should be deleted when clear with unscoped")
}
address2_1 := Address{Name: "address 2-1"}
address2_2 := Address{Name: "address 2-2"}
person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}}
DB.Create(&person2)
if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil {
t.Fatalf("failed to delete person, got error: %v", err)
}
if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 {
t.Errorf("person's addresses expects 2, got %v", count)
}
if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 {
t.Errorf("person's addresses expects 2, got %v", count)
}
}

View File

@ -6,6 +6,7 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
type Blog struct {
@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
t.Fatalf("EN Blog's tags should be cleared")
}
}
func TestCompositePrimaryKeysAssociations(t *testing.T) {
type Label struct {
BookID *uint `gorm:"primarykey"`
Name string `gorm:"primarykey"`
Value string
}
type Book struct {
ID int
Name string
Labels []Label
}
DB.Migrator().DropTable(&Label{}, &Book{})
if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil {
t.Fatalf("failed to migrate")
}
book := Book{
Name: "my book",
Labels: []Label{
{Name: "region", Value: "emea"},
},
}
DB.Create(&book)
var result Book
if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil {
t.Fatalf("failed to preload, got error %v", err)
}
AssertEqual(t, book, result)
}

View File

@ -48,10 +48,22 @@ func TestNamedArg(t *testing.T) {
t.Errorf("failed to update with named arg")
}
namedUser.Name1 = "jinzhu-new"
namedUser.Name2 = "jinzhu-new2"
namedUser.Name3 = "jinzhu-new"
var result5 NamedUser
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil {
t.Errorf("failed to update with named arg")
}
AssertEqual(t, result4, namedUser)
AssertEqual(t, result5, namedUser)
var result6 NamedUser
if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name
AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil {
t.Errorf("failed to update with named arg")
}
AssertEqual(t, result6, namedUser)
}

View File

@ -1,6 +1,8 @@
package tests_test
import (
"encoding/json"
"regexp"
"sort"
"strconv"
"testing"
@ -31,6 +33,20 @@ func TestPreloadWithAssociations(t *testing.T) {
var user2 User
DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
var user3 = *GetUser("preload_with_associations_new", Config{
Account: true,
Pets: 2,
Toys: 3,
Company: true,
Manager: true,
Team: 4,
Languages: 3,
Friends: 1,
})
DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID)
CheckUser(t, user3, user)
}
func TestNestedPreload(t *testing.T) {
@ -174,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) {
CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2])
}
}
func TestPreloadEmptyData(t *testing.T) {
var user = *GetUser("user_without_associations", Config{})
DB.Create(&user)
DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name)
if r, err := json.Marshal(&user); err != nil {
t.Errorf("failed to marshal users, got error %v", err)
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
t.Errorf("json marshal is not empty slice, got %v", string(r))
}
var results []User
DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name)
if r, err := json.Marshal(&results); err != nil {
t.Errorf("failed to marshal users, got error %v", err)
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
t.Errorf("json marshal is not empty slice, got %v", string(r))
}
}

View File

@ -0,0 +1,52 @@
package tests_test
import (
"context"
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
func TestPreparedStmt(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
txCtx := tx.WithContext(ctx)
user := *GetUser("prepared_stmt", Config{})
txCtx.Create(&user)
var result1 User
if err := txCtx.Find(&result1, user.ID).Error; err != nil {
t.Fatalf("no error should happen but got %v", err)
}
time.Sleep(time.Second)
var result2 User
if err := tx.Find(&result2, user.ID).Error; err != nil {
t.Fatalf("no error should happen but got %v", err)
}
user2 := *GetUser("prepared_stmt2", Config{})
if err := txCtx.Create(&user2).Error; err == nil {
t.Fatalf("should failed to create with timeout context")
}
if err := tx.Create(&user2).Error; err != nil {
t.Fatalf("failed to create, got error %v", err)
}
var result3 User
if err := tx.Find(&result3, user2.ID).Error; err != nil {
t.Fatalf("no error should happen but got %v", err)
}
}

View File

@ -1,11 +1,13 @@
package tests_test
import (
"database/sql"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"testing"
"time"
@ -61,6 +63,54 @@ func TestFind(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := first[dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Age":
if _, ok := first[dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Birthday":
if _, ok := first[dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
})
}
}
})
t.Run("FirstMapWithTable", func(t *testing.T) {
var first = map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(first[dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
})
@ -86,13 +136,29 @@ func TestFind(t *testing.T) {
t.Run("FirstSliceOfMap", func(t *testing.T) {
var allMap = []map[string]interface{}{}
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
t.Errorf("errors happened when query find: %v", err)
} else {
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := allMap[idx][dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Age":
if _, ok := allMap[idx][dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Birthday":
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
@ -101,6 +167,58 @@ func TestFind(t *testing.T) {
}
}
})
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
var allMap = []map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query find: %v", err)
} else {
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
}
})
}
}
})
var models []User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models))
} else {
for idx, user := range users {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models[idx], user)
})
}
}
var none []User
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
}
}
func TestQueryWithAssociation(t *testing.T) {
@ -228,6 +346,11 @@ func TestNot(t *testing.T) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Not("name = ?", "jinzhu").Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
@ -309,6 +432,33 @@ func TestPluck(t *testing.T) {
t.Errorf("Unexpected result on pluck id, got %+v", ids)
}
}
var times []time.Time
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &times).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range times {
AssertEqual(t, tv, users[idx].CreatedAt)
}
var ptrtimes []*time.Time
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range ptrtimes {
AssertEqual(t, tv, users[idx].CreatedAt)
}
var nulltimes []sql.NullTime
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range nulltimes {
AssertEqual(t, tv.Time, users[idx].CreatedAt)
}
}
func TestSelect(t *testing.T) {
@ -508,6 +658,7 @@ func TestLimit(t *testing.T) {
{Name: "LimitUser3", Age: 20},
{Name: "LimitUser4", Age: 10},
{Name: "LimitUser5", Age: 20},
{Name: "LimitUser6", Age: 20},
}
DB.Create(&users)
@ -516,7 +667,7 @@ func TestLimit(t *testing.T) {
DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
t.Errorf("Limit should works")
t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3))
}
}
@ -531,6 +682,7 @@ func TestOffset(t *testing.T) {
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
t.Errorf("Offset should work")
}
DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
@ -698,3 +850,11 @@ func TestScanNullValue(t *testing.T) {
t.Fatalf("failed to query slice data with null age, got error %v", err)
}
}
func TestQueryWithTableAndConditions(t *testing.T) {
result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"})
if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
}

View File

@ -6,6 +6,7 @@ import (
"strings"
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -16,14 +17,25 @@ func TestScan(t *testing.T) {
DB.Save(&user1).Save(&user2).Save(&user3)
type result struct {
ID uint
Name string
Age int
}
var res result
DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res)
if res.Name != user3.Name || res.Age != int(user3.Age) {
t.Errorf("Scan into struct should work")
DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res)
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
}
DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res)
if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2)
}
DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res)
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
}
var doubleAgeRes = &result{}
@ -39,11 +51,11 @@ func TestScan(t *testing.T) {
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results)
sort.Slice(results, func(i, j int) bool {
return strings.Compare(results[i].Name, results[j].Name) < -1
return strings.Compare(results[i].Name, results[j].Name) <= -1
})
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
t.Errorf("Scan into struct map")
t.Errorf("Scan into struct map, got %#v", results)
}
}
@ -72,7 +84,21 @@ func TestScanRows(t *testing.T) {
results = append(results, result)
}
sort.Slice(results, func(i, j int) bool {
return strings.Compare(results[i].Name, results[j].Name) <= -1
})
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
t.Errorf("Should find expected results")
}
var ages int
if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 {
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages)
}
var name string
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
}
}

View File

@ -1,6 +1,7 @@
package tests_test
import (
"regexp"
"strings"
"testing"
@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) {
t.Errorf("expects: %v, got %v", expects, result)
}
}
func TestCombineStringConditions(t *testing.T) {
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
}

View File

@ -13,7 +13,7 @@ if [ -d tests ]
then
cd tests
cp go.mod go.mod.bak
sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod
sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod
cd ..
fi

View File

@ -22,7 +22,7 @@ var DB *gorm.DB
func init() {
var err error
if DB, err = OpenTestConnection(); err != nil {
log.Printf("failed to connect database, got error %v\n", err)
log.Printf("failed to connect database, got error %v", err)
os.Exit(1)
} else {
sqlDB, err := DB.DB()
@ -31,7 +31,7 @@ func init() {
}
if err != nil {
log.Printf("failed to connect database, got error %v\n", err)
log.Printf("failed to connect database, got error %v", err)
}
RunMigrations()

View File

@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) {
t.Fatalf("Should find saved record")
}
}
func TestTransactionOnClosedConn(t *testing.T) {
DB, err := OpenTestConnection()
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
rawDB, _ := DB.DB()
rawDB.Close()
if err := DB.Transaction(func(tx *gorm.DB) error {
return nil
}); err == nil {
t.Errorf("should returns error when commit with closed conn, got error %v", err)
}
if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error {
return nil
}); err == nil {
t.Errorf("should returns error when commit with closed conn, got error %v", err)
}
}

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) {
var user2 User
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
user.Company.Name += "new"
user.Manager.Name += "new"
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
}

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) {
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
for _, pet := range user.Pets {
pet.Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Pets").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Pets").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
t.Run("Polymorphic", func(t *testing.T) {
var user = *GetUser("update-has-many", Config{})
@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) {
var user2 User
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
for idx := range user.Toys {
user.Toys[idx].Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Toys").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Toys").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
})
}

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) {
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
user.Account.Number += "new"
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Account").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Account").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
t.Run("Polymorphic", func(t *testing.T) {
var pet = Pet{Name: "create"}
@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) {
var pet2 Pet
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
CheckPet(t, pet2, pet)
pet.Toy.Name += "new"
if err := DB.Save(&pet).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var pet3 Pet
DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
CheckPet(t, pet2, pet3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var pet4 Pet
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
CheckPet(t, pet4, pet)
})
}

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) {
var user2 User
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
for idx := range user.Friends {
user.Friends[idx].Name += "new"
}
for idx := range user.Languages {
user.Languages[idx].Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
}

View File

@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) {
var result2 Language
DB.First(&result2, "code = ?", "save")
AssertEqual(t, result2, lang)
DB.Table("langs").Migrator().DropTable(&Language{})
DB.Table("langs").AutoMigrate(&Language{})
if err := DB.Table("langs").Save(&lang).Error; err != nil {
t.Errorf("no error should happen when creating data, but got %v", err)
}
var result3 Language
if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name {
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3)
}
lang.Name += "name2"
if err := DB.Table("langs").Save(&lang).Error; err != nil {
t.Errorf("no error should happen when creating data, but got %v", err)
}
var result4 Language
if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name {
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
}
}

View File

@ -30,7 +30,7 @@ func FileWithLineNum() string {
}
func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
}
func CheckTruth(val interface{}) bool {
@ -83,3 +83,31 @@ func AssertEqual(src, dst interface{}) bool {
}
return true
}
func ToString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case int:
return strconv.FormatInt(int64(v), 10)
case int8:
return strconv.FormatInt(int64(v), 10)
case int16:
return strconv.FormatInt(int64(v), 10)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(v, 10)
case uint:
return strconv.FormatUint(uint64(v), 10)
case uint8:
return strconv.FormatUint(uint64(v), 10)
case uint16:
return strconv.FormatUint(uint64(v), 10)
case uint32:
return strconv.FormatUint(uint64(v), 10)
case uint64:
return strconv.FormatUint(v, 10)
}
return ""
}