Merge branch 'master' into patch-oracle

This commit is contained in:
Steve Fan 2020-10-03 23:15:42 +08:00 committed by GitHub
commit b2fdc25f73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 1609 additions and 242 deletions

5
.github/labels.json vendored
View File

@ -10,6 +10,11 @@
"colour": "#EDEDED", "colour": "#EDEDED",
"description": "general questions" "description": "general questions"
}, },
"invalid_question": {
"name": "type:invalid question",
"colour": "#CF2E1F",
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
},
"with_playground": { "with_playground": {
"name": "type:with reproduction steps", "name": "type:with reproduction steps",
"colour": "#00ff00", "colour": "#00ff00",

22
.github/workflows/invalid_question.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: "Close invalid questions issues"
on:
schedule:
- cron: "*/10 * * * *"
jobs:
stale:
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v3.0.7
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
remove-stale-when-updated: true
only-labels: "type:invalid question"

View File

@ -13,7 +13,7 @@ jobs:
uses: actions/stale@v3.0.7 uses: actions/stale@v3.0.7
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs." stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale" stale-issue-label: "status:stale"
days-before-stale: 0 days-before-stale: 0
days-before-close: 2 days-before-close: 2

View File

@ -8,7 +8,6 @@ import (
"sort" "sort"
"time" "time"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
@ -75,15 +74,20 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) {
curTime := time.Now() curTime := time.Now()
stmt := db.Statement stmt := db.Statement
db.RowsAffected = 0
if stmt.Model == nil { if stmt.Model == nil {
stmt.Model = stmt.Dest stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
} }
if stmt.Model != nil { if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
db.AddError(err) if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
}
} }
} }
@ -154,7 +158,7 @@ func (p *processor) compile() (err error) {
p.callbacks = callbacks p.callbacks = callbacks
if p.fns, err = sortCallbacks(p.callbacks); err != nil { if p.fns, err = sortCallbacks(p.callbacks); err != nil {
logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
} }
return return
} }
@ -177,7 +181,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
} }
func (c *callback) Remove(name string) error { func (c *callback) Remove(name string) error {
logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name c.name = name
c.remove = true c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c) c.processor.callbacks = append(c.processor.callbacks, c)
@ -185,7 +189,7 @@ func (c *callback) Remove(name string) error {
} }
func (c *callback) Replace(name string, fn func(*DB)) error { func (c *callback) Replace(name string, fn func(*DB)) error {
logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name c.name = name
c.handler = fn c.handler = fn
c.replace = true c.replace = true
@ -215,7 +219,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
for _, c := range cs { for _, c := range cs {
// show warning message the callback name already exists // show warning message the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
} }
names = append(names, c.name) names = append(names, c.name)
} }

View File

@ -5,6 +5,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema"
) )
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
@ -65,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
} }
if elems.Len() > 0 { if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
DoNothing: true,
}).Create(elems.Interface()).Error) == nil {
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i)) setupReferences(objs[i], elems.Index(i))
} }
@ -80,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
rv = rv.Addr() rv = rv.Addr()
} }
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
DoNothing: true,
}).Create(rv.Interface()).Error) == nil {
setupReferences(db.Statement.ReflectValue, rv) setupReferences(db.Statement.ReflectValue, rv)
} }
} }
@ -144,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ db.AddError(db.Session(&gorm.Session{}).Clauses(
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
DoUpdates: clause.AssignmentColumns(assignmentColumns), ).Create(elems.Interface()).Error)
}).Create(elems.Interface()).Error)
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
@ -167,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ db.AddError(db.Session(&gorm.Session{}).Clauses(
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
DoUpdates: clause.AssignmentColumns(assignmentColumns), ).Create(f.Interface()).Error)
}).Create(f.Interface()).Error)
} }
} }
} }
@ -229,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ db.AddError(db.Session(&gorm.Session{}).Clauses(
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
DoUpdates: clause.AssignmentColumns(assignmentColumns), ).Create(elems.Interface()).Error)
}).Create(elems.Interface()).Error)
} }
} }
@ -297,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
if elems.Len() > 0 { if elems.Len() > 0 {
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i)) appendToJoins(objs[i], elems.Index(i))
@ -310,3 +304,32 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
} }
} }
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
if stmt.DB.FullSaveAssociations {
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
for _, dbName := range s.DBNames {
if !s.LookUpField(dbName).PrimaryKey {
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
}
}
}
if len(defaultUpdatingColumns) > 0 {
var columns []clause.Column
if s.PrioritizedPrimaryField != nil {
columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
} else {
for _, dbName := range s.PrimaryFieldDBNames {
columns = append(columns, clause.Column{Name: dbName})
}
}
return clause.OnConflict{
Columns: columns,
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
}
}
return clause.OnConflict{DoNothing: true}
}

View File

@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback := db.Callback().Delete() deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)

View File

@ -1,6 +1,7 @@
package callbacks package callbacks
import ( import (
"fmt"
"reflect" "reflect"
"gorm.io/gorm" "gorm.io/gorm"
@ -87,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) {
} }
case reflect.Struct: case reflect.Struct:
if insertID > 0 { if insertID > 0 {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
} }
} }
} else { } else {
@ -252,8 +256,18 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) stmt.SQL.Grow(stmt.ReflectValue.Len() * 15)
values.Values = make([][]interface{}, stmt.ReflectValue.Len()) values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
if stmt.ReflectValue.Len() == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i)) rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]interface{}, len(values.Columns)) values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]

View File

@ -2,6 +2,7 @@ package callbacks
import ( import (
"reflect" "reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -21,6 +22,85 @@ func BeforeDelete(db *gorm.DB) {
} }
} }
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if restricted {
for column, v := range selectColumns {
if v {
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{}).Model(modelValue)
withoutConditions := false
if len(db.Statement.Selects) > 0 {
var selects []string
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if strings.HasPrefix(s, column+".") {
selects = append(selects, strings.TrimPrefix(s, column+"."))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions {
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
case schema.Many2Many:
var (
queryConds []clause.Expression
foreignFields []*schema.Field
relForeignKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
}
}
}
func Delete(db *gorm.DB) { func Delete(db *gorm.DB) {
if db.Error == nil { if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil && !db.Statement.Unscoped {

View File

@ -46,6 +46,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
) )
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
for idx, mapValue := range mapValues { for idx, mapValue := range mapValues {
for k, v := range mapValue { for k, v := range mapValue {
if stmt.Schema != nil { if stmt.Schema != nil {

View File

@ -107,6 +107,26 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
fieldValues := make([]interface{}, len(relForeignFields)) fieldValues := make([]interface{}, len(relForeignFields))
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
}
}
}
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i) elem := reflectResults.Index(i)
for idx, field := range relForeignFields { for idx, field := range relForeignFields {

View File

@ -11,11 +11,13 @@ func RowQuery(db *gorm.DB) {
} }
if !db.DryRun { if !db.DryRun {
if _, ok := db.Get("rows"); ok { if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else { } else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} }
db.RowsAffected = -1
} }
} }
} }

View File

@ -198,7 +198,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
// Order specify order when retrieve records from database // Order specify order when retrieve records from database
// db.Order("name DESC") // db.Order("name DESC")
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) { func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()

View File

@ -3,6 +3,7 @@ package clause
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"go/ast"
"reflect" "reflect"
) )
@ -30,18 +31,22 @@ func (expr Expr) Build(builder Builder) {
) )
for _, v := range []byte(expr.SQL) { for _, v := range []byte(expr.SQL) {
if v == '?' { if v == '?' && len(expr.Vars) > idx {
if afterParenthesis { if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok { if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx]) builder.AddVar(builder, expr.Vars[idx])
} else { } else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ { if rv.Len() == 0 {
if i > 0 { builder.AddVar(builder, nil)
builder.WriteByte(',') } else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
} }
builder.AddVar(builder, rv.Index(i).Interface())
} }
default: default:
builder.AddVar(builder, expr.Vars[idx]) builder.AddVar(builder, expr.Vars[idx])
@ -85,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) {
for k, v := range value { for k, v := range value {
namedMap[k] = v namedMap[k] = v
} }
default:
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
}
}
}
} }
} }
@ -94,7 +110,7 @@ func (expr NamedExpr) Build(builder Builder) {
if v == '@' && !inName { if v == '@' && !inName {
inName = true inName = true
name = []byte{} name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
if inName { if inName {
if nv, ok := namedMap[string(name)]; ok { if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv) builder.AddVar(builder, nv)
@ -106,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) {
} }
builder.WriteByte(v) builder.WriteByte(v)
} else if v == '?' { } else if v == '?' && len(expr.Vars) > idx {
builder.AddVar(builder, expr.Vars[idx]) builder.AddVar(builder, expr.Vars[idx])
idx++ idx++
} else if inName { } else if inName {

View File

@ -37,6 +37,11 @@ func TestExpr(t *testing.T) {
} }
func TestNamedExpr(t *testing.T) { func TestNamedExpr(t *testing.T) {
type NamedArgument struct {
Name1 string
Name2 string
}
results := []struct { results := []struct {
SQL string SQL string
Result string Result string
@ -66,6 +71,15 @@ func TestNamedExpr(t *testing.T) {
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
}, {
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
}, {
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{},
Result: "create table ? (? ?, ? ?)",
}} }}
for idx, result := range results { for idx, result := range results {

View File

@ -33,10 +33,8 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Limit); ok { if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit > 0 { if limit.Limit == 0 && v.Limit != 0 {
limit.Limit = v.Limit limit.Limit = v.Limit
} else if limit.Limit < 0 {
limit.Limit = 0
} }
if limit.Offset == 0 && v.Offset > 0 { if limit.Offset == 0 && v.Offset > 0 {

View File

@ -1,5 +1,9 @@
package clause package clause
import (
"strings"
)
// Where where clause // Where where clause
type Where struct { type Where struct {
Exprs []Expression Exprs []Expression
@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) {
} }
} }
wrapInParentheses := false
for idx, expr := range where.Exprs { for idx, expr := range where.Exprs {
if idx > 0 { if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) {
} }
} }
expr.Build(builder) if len(where.Exprs) > 1 {
switch v := expr.(type) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case Expr:
sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
if wrapInParentheses {
builder.WriteString(`(`)
expr.Build(builder)
builder.WriteString(`)`)
wrapInParentheses = false
} else {
expr.Build(builder)
}
} }
} }
@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) {
func And(exprs ...Expression) Expression { func And(exprs ...Expression) Expression {
if len(exprs) == 0 { if len(exprs) == 0 {
return nil return nil
} else if len(exprs) == 1 {
return exprs[0]
} }
return AndConditions{Exprs: exprs} return AndConditions{Exprs: exprs}
} }
@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
} }
for idx, c := range not.Exprs { for idx, c := range not.Exprs {
if idx > 0 { if idx > 0 {
builder.WriteString(" AND ") builder.WriteString(" AND ")
@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) {
negationBuilder.NegationBuild(builder) negationBuilder.NegationBuild(builder)
} else { } else {
builder.WriteString("NOT ") builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToLower(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder) c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
} }
} }
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
} }

View File

@ -27,4 +27,8 @@ var (
ErrRegistered = errors.New("registered") ErrRegistered = errors.New("registered")
// ErrInvalidField invalid field // ErrInvalidField invalid field
ErrInvalidField = errors.New("invalid field") ErrInvalidField = errors.New("invalid field")
// ErrEmptySlice empty slice found
ErrEmptySlice = errors.New("empty slice found")
// ErrDryRunModeUnsupported dry run mode unsupported
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
) )

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
@ -51,7 +52,8 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.callbacks.Update().Execute(tx) tx.callbacks.Update().Execute(tx)
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) { result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value) return tx.Create(value)
} }
} }
@ -331,22 +333,43 @@ func (db *DB) Count(count *int64) (tx *DB) {
} }
func (db *DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
tx := db.getInstance() tx := db.getInstance().InstanceSet("rows", false)
tx.callbacks.Row().Execute(tx) tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Row) row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
}
return row
} }
func (db *DB) Rows() (*sql.Rows, error) { func (db *DB) Rows() (*sql.Rows, error) {
tx := db.Set("rows", true) tx := db.getInstance().InstanceSet("rows", true)
tx.callbacks.Row().Execute(tx) tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Rows), tx.Error rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
}
return rows, tx.Error
} }
// Scan scan value to a struct // Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
currentLogger, newLogger := db.Logger, logger.Recorder.New()
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = dest tx.Logger = newLogger
tx.callbacks.Query().Execute(tx) if rows, err := tx.Rows(); err != nil {
tx.AddError(err)
} else {
defer rows.Close()
if rows.Next() {
tx.ScanRows(rows, dest)
}
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
tx.Logger = currentLogger
return return
} }
@ -377,9 +400,14 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance() tx := db.getInstance()
tx.Error = tx.Statement.Parse(dest) if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
}
Scan(rows, tx, true) Scan(rows, tx, true)
return tx.Error return tx.Error
} }

13
gorm.go
View File

@ -20,6 +20,8 @@ type Config struct {
SkipDefaultTransaction bool SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger // Logger
Logger logger.Interface Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp // NowFunc the function to be used when creating a new timestamp
@ -64,6 +66,7 @@ type Session struct {
WithConditions bool WithConditions bool
SkipDefaultTransaction bool SkipDefaultTransaction bool
AllowGlobalUpdate bool AllowGlobalUpdate bool
FullSaveAssociations bool
Context context.Context Context context.Context
Logger logger.Interface Logger logger.Interface
NowFunc func() time.Time NowFunc func() time.Time
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.AllowGlobalUpdate = true txConfig.AllowGlobalUpdate = true
} }
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
if config.Context != nil { if config.Context != nil {
tx.Statement = tx.Statement.clone() tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx tx.Statement.DB = tx
@ -169,12 +176,15 @@ func (db *DB) Session(config *Session) *DB {
if config.PrepareStmt { if config.PrepareStmt {
if v, ok := db.cacheStore.Load("preparedStmt"); ok { if v, ok := db.cacheStore.Load("preparedStmt"); ok {
tx.Statement = tx.Statement.clone()
preparedStmt := v.(*PreparedStmtDB) preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{ tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool, ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux, Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts, Stmts: preparedStmt.Stmts,
} }
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
} }
} }
@ -316,6 +326,9 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f ref.ForeignKey = f
} else { } else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)

View File

@ -2,6 +2,7 @@ package logger
import ( import (
"context" "context"
"io/ioutil"
"log" "log"
"os" "os"
"time" "time"
@ -19,6 +20,7 @@ const (
Magenta = "\033[35m" Magenta = "\033[35m"
Cyan = "\033[36m" Cyan = "\033[36m"
White = "\033[37m" White = "\033[37m"
BlueBold = "\033[34;1m"
MagentaBold = "\033[35;1m" MagentaBold = "\033[35;1m"
RedBold = "\033[31;1m" RedBold = "\033[31;1m"
YellowBold = "\033[33;1m" YellowBold = "\033[33;1m"
@ -54,29 +56,33 @@ type Interface interface {
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
} }
var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ var (
SlowThreshold: 100 * time.Millisecond, Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
LogLevel: Warn, Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
Colorful: true, SlowThreshold: 100 * time.Millisecond,
}) LogLevel: Warn,
Colorful: true,
})
Recorder = traceRecorder{Interface: Default}
)
func New(writer Writer, config Config) Interface { func New(writer Writer, config Config) Interface {
var ( var (
infoStr = "%s\n[info] " infoStr = "%s\n[info] "
warnStr = "%s\n[warn] " warnStr = "%s\n[warn] "
errStr = "%s\n[error] " errStr = "%s\n[error] "
traceStr = "%s\n[%v] [rows:%d] %s" traceStr = "%s\n[%.3fms] [rows:%v] %s"
traceWarnStr = "%s\n[%v] [rows:%d] %s" traceWarnStr = "%s\n[%.3fms] [rows:%v] %s"
traceErrStr = "%s %s\n[%v] [rows:%d] %s" traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
) )
if config.Colorful { if config.Colorful {
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
} }
return &logger{ return &logger{
@ -133,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
switch { switch {
case err != nil && l.LogLevel >= Error: case err != nil && l.LogLevel >= Error:
sql, rows := fc() sql, rows := fc()
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc() sql, rows := fc()
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel >= Info: case l.LogLevel >= Info:
sql, rows := fc() sql, rows := fc()
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
} }
} }
} }
type traceRecorder struct {
Interface
BeginAt time.Time
SQL string
RowsAffected int64
Err error
}
func (l traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface}
}
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin
l.SQL, l.RowsAffected = fc()
l.Err = err
}

View File

@ -9,6 +9,8 @@ import (
"strings" "strings"
"time" "time"
"unicode" "unicode"
"gorm.io/gorm/utils"
) )
func isPrintable(s []byte) bool { func isPrintable(s []byte) bool {
@ -24,19 +26,38 @@ var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int) var convertParams func(interface{}, int)
var vars = make([]interface{}, len(avars)) var vars = make([]string, len(avars))
copy(vars, avars)
convertParams = func(v interface{}, idx int) { convertParams = func(v interface{}, idx int) {
switch v := v.(type) { switch v := v.(type) {
case bool: case bool:
vars[idx] = fmt.Sprint(v) vars[idx] = strconv.FormatBool(v)
case time.Time: case time.Time:
if v.IsZero() { if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else { } else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
} }
case *time.Time:
if v != nil {
if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
}
} else {
vars[idx] = "NULL"
}
case fmt.Stringer:
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
case driver.Valuer:
reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
r, _ := v.Value()
convertParams(r, idx)
} else {
vars[idx] = "NULL"
}
case []byte: case []byte:
if isPrintable(v) { if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
@ -44,7 +65,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
vars[idx] = escaper + "<binary>" + escaper vars[idx] = escaper + "<binary>" + escaper
} }
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = fmt.Sprintf("%d", v) vars[idx] = utils.ToString(v)
case float64, float32: case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v) vars[idx] = fmt.Sprintf("%.6f", v)
case string: case string:
@ -70,18 +91,30 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
} }
} }
for idx, v := range vars { for idx, v := range avars {
convertParams(v, idx) convertParams(v, idx)
} }
if numericPlaceholder == nil { if numericPlaceholder == nil {
for _, v := range vars { var idx int
sql = strings.Replace(sql, "?", v.(string), 1) var newSQL strings.Builder
for _, v := range []byte(sql) {
if v == '?' {
if len(vars) > idx {
newSQL.WriteString(vars[idx])
idx++
continue
}
}
newSQL.WriteByte(v)
} }
sql = newSQL.String()
} else { } else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars { for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1) sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
} }
} }

View File

@ -1,13 +1,39 @@
package logger_test package logger_test
import ( import (
"database/sql/driver"
"encoding/json"
"fmt"
"regexp" "regexp"
"strings"
"testing" "testing"
"github.com/jinzhu/now" "github.com/jinzhu/now"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
type JSON json.RawMessage
func (j JSON) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
return json.RawMessage(j).MarshalJSON()
}
type ExampleStruct struct {
Name string
Val string
}
func (s ExampleStruct) Value() (driver.Value, error) {
return json.Marshal(s)
}
func format(v []byte, escaper string) string {
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
}
func TestExplainSQL(t *testing.T) { func TestExplainSQL(t *testing.T) {
type role string type role string
type password []byte type password []byte
@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) {
tt = now.MustParse("2020-02-23 11:10:10") tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin") myrole = role("admin")
pwd = password([]byte("pass")) pwd = password([]byte("pass"))
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
) )
results := []struct { results := []struct {
@ -29,6 +59,12 @@ func TestExplainSQL(t *testing.T) {
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
}, },
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`), NumericRegexp: regexp.MustCompile(`@p(\d+)`),
@ -47,6 +83,18 @@ func TestExplainSQL(t *testing.T) {
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
} }
for idx, r := range results { for idx, r := range results {

View File

@ -9,7 +9,7 @@ import (
// Migrator returns migrator // Migrator returns migrator
func (db *DB) Migrator() Migrator { func (db *DB) Migrator() Migrator {
return db.Dialector.Migrator(db) return db.Dialector.Migrator(db.Session(&Session{WithConditions: true}))
} }
// AutoMigrate run auto migration for given models // AutoMigrate run auto migration for given models

View File

@ -71,7 +71,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
} else { } else if field.DefaultValue != "(-)" {
expr.SQL += " DEFAULT " + field.DefaultValue expr.SQL += " DEFAULT " + field.DefaultValue
} }
} }
@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} }
} }
for _, idx := range stmt.Schema.ParseIndexes() {
if !tx.Migrator().HasIndex(value, idx.Name) {
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
return nil return nil
}); err != nil { }); err != nil {
return err return err
@ -297,10 +306,12 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType,
).Error ).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -354,9 +365,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true alterColumn = true
} else { } else {
// has size in data type and not equal // has size in data type and not equal
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1)
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
alterColumn = true alterColumn = true
} }
} }
@ -386,8 +397,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
err = m.RunWithValue(value, func(stmt *gorm.Statement) error { err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err == nil { if err == nil {
defer rows.Close()
columnTypes, err = rows.ColumnTypes() columnTypes, err = rows.ColumnTypes()
} }
return err return err
@ -584,6 +596,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
var ( var (
modelNames, orderedModelNames []string modelNames, orderedModelNames []string
orderedModelNamesMap = map[string]bool{} orderedModelNamesMap = map[string]bool{}
parsedSchemas = map[*schema.Schema]bool{}
valuesMap = map[string]Dependency{} valuesMap = map[string]Dependency{}
insertIntoOrderedList func(name string) insertIntoOrderedList func(name string)
parseDependence func(value interface{}, addToList bool) parseDependence func(value interface{}, addToList bool)
@ -593,23 +606,35 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
dep := Dependency{ dep := Dependency{
Statement: &gorm.Statement{DB: m.DB, Dest: value}, Statement: &gorm.Statement{DB: m.DB, Dest: value},
} }
beDependedOn := map[*schema.Schema]bool{}
if err := dep.Parse(value); err != nil { if err := dep.Parse(value); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
} }
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
return
}
parsedSchemas[dep.Statement.Schema] = true
for _, rel := range dep.Schema.Relationships.Relations { for _, rel := range dep.Schema.Relationships.Relations {
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema) dep.Depends = append(dep.Depends, c.ReferenceSchema)
} }
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
beDependedOn[rel.FieldSchema] = true
}
if rel.JoinTable != nil { if rel.JoinTable != nil {
if rel.Schema != rel.FieldSchema {
dep.Depends = append(dep.Depends, rel.FieldSchema)
}
// append join value // append join value
defer func(joinValue interface{}) { defer func(rel *schema.Relationship, joinValue interface{}) {
if !beDependedOn[rel.FieldSchema] {
dep.Depends = append(dep.Depends, rel.FieldSchema)
} else {
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
parseDependence(fieldValue, autoAdd)
}
parseDependence(joinValue, autoAdd) parseDependence(joinValue, autoAdd)
}(reflect.New(rel.JoinTable.ModelType).Interface()) }(rel, reflect.New(rel.JoinTable.ModelType).Interface())
} }
} }

View File

@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() {
db.Mux.Unlock() db.Mux.Unlock()
} }
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) {
db.Mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok { if stmt, ok := db.Stmts[query]; ok {
db.Mux.RUnlock() db.Mux.RUnlock()
@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
return stmt, nil return stmt, nil
} }
stmt, err := db.ConnPool.PrepareContext(context.Background(), query) stmt, err := db.ConnPool.PrepareContext(ctx, query)
if err == nil { if err == nil {
db.Stmts[query] = stmt db.Stmts[query] = stmt
db.PreparedSQL = append(db.PreparedSQL, query) db.PreparedSQL = append(db.PreparedSQL, query)
@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
} }
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(query) stmt, err := db.prepare(ctx, query)
if err == nil { if err == nil {
result, err = stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
if err != nil { if err != nil {
@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
} }
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(query) stmt, err := db.prepare(ctx, query)
if err == nil { if err == nil {
rows, err = stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
if err != nil { if err != nil {
@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
} }
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(query) stmt, err := db.prepare(ctx, query)
if err == nil { if err == nil {
return stmt.QueryRowContext(ctx, args...) return stmt.QueryRowContext(ctx, args...)
} }
@ -99,10 +99,24 @@ type PreparedStmtTX struct {
PreparedStmtDB *PreparedStmtDB PreparedStmtDB *PreparedStmtDB
} }
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil {
return tx.Tx.Commit()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil {
return tx.Tx.Rollback()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query) stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
if err == nil { if err == nil {
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)
if err != nil { if err != nil {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Mux.Lock()
stmt.Close() stmt.Close()
@ -114,7 +128,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
} }
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query) stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
if err == nil { if err == nil {
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
if err != nil { if err != nil {
@ -128,9 +142,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
} }
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(query) stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
if err == nil { if err == nil {
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
} }
return &sql.Row{} return &sql.Row{}
} }

86
scan.go
View File

@ -2,22 +2,63 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"reflect" "reflect"
"strings" "strings"
"time"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
}
} else {
for idx := range columns {
values[idx] = new(interface{})
}
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
mapValue[column] = string(b)
}
} else {
mapValue[column] = nil
}
}
}
func Scan(rows *sql.Rows, db *DB, initialized bool) { func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns() columns, _ := rows.Columns()
values := make([]interface{}, len(columns)) values := make([]interface{}, len(columns))
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}: case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() { if initialized || rows.Next() {
for idx := range columns { columnTypes, _ := rows.ColumnTypes()
values[idx] = new(interface{}) prepareValues(values, db, columnTypes, columns)
}
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
@ -28,41 +69,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
mapValue = *v mapValue = *v
} }
} }
scanIntoMap(mapValue, values, columns)
for idx, column := range columns {
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
} }
case *[]map[string]interface{}: case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() { for initialized || rows.Next() {
for idx := range columns { prepareValues(values, db, columnTypes, columns)
values[idx] = new(interface{})
}
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{} mapValue := map[string]interface{}{}
for idx, column := range columns { scanIntoMap(mapValue, values, columns)
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
*dest = append(*dest, mapValue) *dest = append(*dest, mapValue)
} }
case *int, *int64, *uint, *uint64, *float32, *float64: case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time:
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
@ -114,7 +136,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
// pluck values into slice of data // pluck values into slice of data
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok {
isPluck = true
} else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) {
isPluck = true
}
}
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++

View File

@ -18,6 +18,8 @@ type DataType string
type TimeType int64 type TimeType int64
var TimeReflectType = reflect.TypeOf(time.Time{})
const ( const (
UnixSecond TimeType = 1 UnixSecond TimeType = 1
UnixMillisecond TimeType = 2 UnixMillisecond TimeType = 2
@ -70,6 +72,8 @@ type Field struct {
} }
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error
field := &Field{ field := &Field{
Name: fieldStruct.Name, Name: fieldStruct.Name,
BindNames: []string{fieldStruct.Name}, BindNames: []string{fieldStruct.Name},
@ -100,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var getRealFieldValue func(reflect.Value) var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) { getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v) rv := reflect.Indirect(v)
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
for i := 0; i < rv.Type().NumField(); i++ { for i := 0; i < rv.Type().NumField(); i++ {
newFieldType := rv.Type().Field(i).Type newFieldType := rv.Type().Field(i).Type
for newFieldType.Kind() == reflect.Ptr { for newFieldType.Kind() == reflect.Ptr {
@ -151,7 +155,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
if num, ok := field.TagSettings["SIZE"]; ok { if num, ok := field.TagSettings["SIZE"]; ok {
var err error
if field.Size, err = strconv.Atoi(num); err != nil { if field.Size, err = strconv.Atoi(num); err != nil {
field.Size = -1 field.Size = -1
} }
@ -177,33 +180,42 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Comment = val field.Comment = val
} }
// default value is function or null or blank (primary keys)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
switch reflect.Indirect(fieldValue).Kind() { switch reflect.Indirect(fieldValue).Kind() {
case reflect.Bool: case reflect.Bool:
field.DataType = Bool field.DataType = Bool
if field.HasDefaultValue && field.DefaultValue != "" { if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue) if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err)
}
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.DataType = Int field.DataType = Int
if field.HasDefaultValue && field.DefaultValue != "" { if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64) if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err)
}
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.DataType = Uint field.DataType = Uint
if field.HasDefaultValue && field.DefaultValue != "" { if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64) if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err)
}
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.DataType = Float field.DataType = Float
if field.HasDefaultValue && field.DefaultValue != "" { if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64) if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err)
}
} }
case reflect.String: case reflect.String:
field.DataType = String field.DataType = String
isFunc := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")")
if field.HasDefaultValue && !isFunc { if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValue = strings.Trim(field.DefaultValue, "\"")
field.DefaultValueInterface = field.DefaultValue field.DefaultValueInterface = field.DefaultValue
@ -211,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
case reflect.Struct: case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok { if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { } else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
field.DataType = Time field.DataType = Time
@ -316,7 +328,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
cacheStore := &sync.Map{} cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true) cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
schema.err = err schema.err = err
} }
@ -331,23 +343,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
} }
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
ef.DBName = prefix + ef.DBName ef.DBName = prefix + ef.DBName
} }
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { if ef.PrimaryKey {
ef.PrimaryKey = true if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { ef.PrimaryKey = true
ef.PrimaryKey = true } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
} else { ef.PrimaryKey = true
ef.PrimaryKey = false } else {
ef.PrimaryKey = false
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
ef.AutoIncrement = false ef.AutoIncrement = false
} }
if ef.DefaultValue == "" { if ef.DefaultValue == "" {
ef.HasDefaultValue = false ef.HasDefaultValue = false
}
} }
} }
@ -612,6 +626,14 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case []byte: case []byte:
return field.Set(value, string(data)) return field.Set(value, string(data))
case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
}
case string: case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil { if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetUint(i) field.ReflectValueOf(value).SetUint(i)
@ -671,7 +693,7 @@ func (field *Field) setupValuerAndSetter() {
case []byte: case []byte:
field.ReflectValueOf(value).SetString(string(data)) field.ReflectValueOf(value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(value).SetString(fmt.Sprint(data)) field.ReflectValueOf(value).SetString(utils.ToString(data))
case float64, float32: case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default: default:

View File

@ -14,7 +14,7 @@ import (
type Namer interface { type Namer interface {
TableName(table string) string TableName(table string) string
ColumnName(table, column string) string ColumnName(table, column string) string
JoinTableName(table string) string JoinTableName(joinTable string) string
RelationshipFKName(Relationship) string RelationshipFKName(Relationship) string
CheckerName(table, column string) string CheckerName(table, column string) string
IndexName(table, column string) string IndexName(table, column string) string
@ -41,6 +41,10 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
// JoinTableName convert string to join table name // JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string { func (ns NamingStrategy) JoinTableName(str string) string {
if strings.ToLower(str) == str {
return ns.TablePrefix + str
}
if ns.SingularTable { if ns.SingularTable {
return ns.TablePrefix + toDBName(str) return ns.TablePrefix + toDBName(str)
} }
@ -49,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string {
// RelationshipFKName generate fk name for relation // RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)) return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1)
} }
// CheckerName generate checker name // CheckerName generate checker name
func (ns NamingStrategy) CheckerName(table, column string) string { func (ns NamingStrategy) CheckerName(table, column string) string {
return fmt.Sprintf("chk_%s_%s", table, column) return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1)
} }
// IndexName generate index name // IndexName generate index name
func (ns NamingStrategy) IndexName(table, column string) string { func (ns NamingStrategy) IndexName(table, column string) string {
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
idxName = strings.Replace(idxName, ".", "_", -1)
if utf8.RuneCountInString(idxName) > 64 { if utf8.RuneCountInString(idxName) > 64 {
h := sha1.New() h := sha1.New()

View File

@ -32,3 +32,35 @@ func TestToDBName(t *testing.T) {
} }
} }
} }
func TestNamingStrategy(t *testing.T) {
var ns = NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_name" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.user_languages" {
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.user_language" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.company" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}

View File

@ -165,6 +165,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
// use same data type for foreign keys // use same data type for foreign keys
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
if relation.Polymorphic.PolymorphicID.Size == 0 {
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
}
relation.References = append(relation.References, &Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryKeyField, PrimaryKey: primaryKeyField,
@ -254,12 +257,18 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}) })
} }
joinTableFields = append(joinTableFields, reflect.StructField{
Name: schema.Name + field.Name,
Type: schema.ModelType,
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err schema.err = err
} }
relation.JoinTable.Name = many2many relation.JoinTable.Name = many2many
relation.JoinTable.Table = schema.namer.JoinTableName(many2many) relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
relName := relation.Schema.Name relName := relation.Schema.Name
relRefName := relation.FieldSchema.Name relRefName := relation.FieldSchema.Name
@ -290,36 +299,41 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
// build references // build references
for idx, f := range relation.JoinTable.Fields { for _, f := range relation.JoinTable.Fields {
// use same data type for foreign keys if f.Creatable || f.Readable || f.Updatable {
f.DataType = fieldsMap[f.Name].DataType // use same data type for foreign keys
f.GORMDataType = fieldsMap[f.Name].GORMDataType f.DataType = fieldsMap[f.Name].DataType
relation.JoinTable.PrimaryFields[idx] = f f.GORMDataType = fieldsMap[f.Name].GORMDataType
ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] if f.Size == 0 {
f.Size = fieldsMap[f.Name].Size
if ownPriamryField {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
})
} else {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
} }
joinRefRel.References = append(joinRefRel.References, &Reference{ relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
PrimaryKey: fieldsMap[f.Name], ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
ForeignKey: f,
if ownPriamryField {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
})
} else {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
}
joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
})
}
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPriamryField,
}) })
} }
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPriamryField,
})
} }
} }
@ -428,6 +442,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
// use same data type for foreign keys // use same data type for foreign keys
foreignField.DataType = primaryFields[idx].DataType foreignField.DataType = primaryFields[idx].DataType
foreignField.GORMDataType = primaryFields[idx].GORMDataType foreignField.GORMDataType = primaryFields[idx].GORMDataType
if foreignField.Size == 0 {
foreignField.Size = primaryFields[idx].Size
}
relation.References = append(relation.References, &Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx], PrimaryKey: primaryFields[idx],

View File

@ -206,16 +206,39 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type User struct { type User struct {
gorm.Model gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint Refer uint
} }
checkStructRelation(t, &User{}, Relation{ checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{ References: []Reference{
{"ID", "User", "UserReferID", "user_profiles", "", true}, {"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profiles", "", false}, {"ID", "Profile", "ProfileRefer", "user_profile", "", false},
},
})
}
func TestBuildReadonlyMany2ManyRelation(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{
{"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
}, },
}) })
} }
@ -267,3 +290,34 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) {
}, },
) )
} }
func TestMultipleMany2Many(t *testing.T) {
type Thing struct {
ID int
}
type Person struct {
ID int
Likes []Thing `gorm:"many2many:likes"`
Dislikes []Thing `gorm:"many2many:dislikes"`
}
checkStructRelation(t, &Person{},
Relation{
Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "likes", Table: "likes"},
References: []Reference{
{"ID", "Person", "PersonID", "likes", "", true},
{"ID", "Thing", "ThingID", "likes", "", false},
},
},
Relation{
Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"},
References: []Reference{
{"ID", "Person", "PersonID", "dislikes", "", true},
{"ID", "Thing", "ThingID", "dislikes", "", false},
},
},
)
}

View File

@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if tabler, ok := modelValue.Interface().(Tabler); ok { if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName() tableName = tabler.TableName()
} }
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
schema := &Schema{ schema := &Schema{
Name: modelType.Name(), Name: modelType.Name(),
@ -133,7 +136,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if field.DBName != "" { if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission // nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok { if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName) schema.DBNames = append(schema.DBNames, field.DBName)
} }
@ -219,7 +222,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil { if schema.parseRelation(field); schema.err != nil {
return schema, schema.err return schema, schema.err
} }

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"strings"
"sync" "sync"
"testing" "testing"
@ -194,6 +195,7 @@ func TestEmbeddedStruct(t *testing.T) {
ID int ID int
OwnerID int OwnerID int
Name string Name string
Ignored string `gorm:"-"`
} }
type Corp struct { type Corp struct {
@ -211,15 +213,87 @@ func TestEmbeddedStruct(t *testing.T) {
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
} }
for _, f := range fields { for _, f := range fields {
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
f.Creatable = true if f.Name != "Ignored" {
f.Updatable = true f.Creatable = true
f.Readable = true f.Updatable = true
f.Readable = true
}
})
}
}
type CustomizedNamingStrategy struct {
schema.NamingStrategy
}
func (ns CustomizedNamingStrategy) ColumnName(table, column string) string {
baseColumnName := ns.NamingStrategy.ColumnName(table, column)
if table == "" {
return baseColumnName
}
s := strings.Split(table, "_")
var prefix string
switch len(s) {
case 1:
prefix = s[0][:3]
case 2:
prefix = s[0][:1] + s[1][:2]
default:
prefix = s[0][:1] + s[1][:1] + s[2][:1]
}
return prefix + "_" + baseColumnName
}
func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
type CorpBase struct {
gorm.Model
OwnerID string
}
type Company struct {
ID int
OwnerID int
Name string
Ignored string `gorm:"-"`
}
type Corp struct {
CorpBase
Base Company `gorm:"embedded;embeddedPrefix:company_"`
}
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}})
if err != nil {
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
}
fields := []schema.Field{
{Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
{Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
{Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
}
for _, f := range fields {
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
if f.Name != "Ignored" {
f.Creatable = true
f.Updatable = true
f.Readable = true
}
}) })
} }
} }

View File

@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa
return columns, queryValues return columns, queryValues
} }
} }
type embeddedNamer struct {
Table string
Namer
}

View File

@ -25,14 +25,7 @@ func (n DeletedAt) Value() (driver.Value, error) {
} }
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{ return []clause.Interface{SoftDeleteQueryClause{Field: f}}
clause.Where{Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
Value: nil,
},
}},
}
} }
type SoftDeleteQueryClause struct { type SoftDeleteQueryClause struct {

View File

@ -12,6 +12,7 @@ import (
"sync" "sync"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
@ -189,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
writer.WriteString("(NULL)") writer.WriteString("(NULL)")
} }
case *DB: case *DB:
subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance()
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
subdb.callbacks.Query().Execute(subdb) subdb.callbacks.Query().Execute(subdb)
writer.WriteString(subdb.Statement.SQL.String()) writer.WriteString(subdb.Statement.SQL.String())
@ -298,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
values := make([]interface{}, reflectValue.Len()) if _, ok := v[key].(driver.Valuer); ok {
for i := 0; i < reflectValue.Len(); i++ { conds = append(conds, clause.Eq{Column: key, Value: v[key]})
values[i] = reflectValue.Index(i).Interface() } else if _, ok := v[key].(Valuer); ok {
} conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else {
values := make([]interface{}, reflectValue.Len())
for i := 0; i < reflectValue.Len(); i++ {
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values}) conds = append(conds, clause.IN{Column: key, Values: values})
}
default: default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} }
@ -317,9 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
if field.Readable { if field.Readable {
if v, isZero := field.ValueOf(reflectValue); !isZero { if v, isZero := field.ValueOf(reflectValue); !isZero {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
} }
} }
} }
@ -330,9 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
if field.Readable { if field.Readable {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
} }
} }
} }

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"errors"
"testing" "testing"
"time" "time"
@ -287,6 +288,30 @@ func TestCreateEmptyStruct(t *testing.T) {
} }
} }
func TestCreateEmptySlice(t *testing.T) {
var data = []User{}
if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice {
t.Errorf("no data should be created, got %v", err)
}
var sliceMap = []map[string]interface{}{}
if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice {
t.Errorf("no data should be created, got %v", err)
}
}
func TestCreateInvalidSlice(t *testing.T) {
users := []*User{
GetUser("invalid_slice_1", Config{}),
GetUser("invalid_slice_2", Config{}),
nil,
}
if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) {
t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
}
}
func TestCreateWithExistingTimestamp(t *testing.T) { func TestCreateWithExistingTimestamp(t *testing.T) {
user := User{Name: "CreateUserExistingTimestamp"} user := User{Name: "CreateUserExistingTimestamp"}
curTime := now.MustParse("2016-01-01") curTime := now.MustParse("2016-01-01")

View File

@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) {
FieldAllowSave3 string `gorm:"->:false;<-:create"` FieldAllowSave3 string `gorm:"->:false;<-:create"`
FieldReadonly string `gorm:"->"` FieldReadonly string `gorm:"->"`
FieldIgnore string `gorm:"-"` FieldIgnore string `gorm:"-"`
AutoUnixCreateTime int64 `gorm:"autocreatetime"` AutoUnixCreateTime int32 `gorm:"autocreatetime"`
AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"` AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"`
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"`
AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"` AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"`
AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"`
} }
DB.Migrator().DropTable(&CustomizeFieldStruct{}) DB.Migrator().DropTable(&CustomizeFieldStruct{})
@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) {
t.Fatalf("invalid result: %#v", result) t.Fatalf("invalid result: %#v", result)
} }
if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 { if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 {
t.Fatalf("invalid create/update unix time: %#v", result) t.Fatalf("invalid create/update unix time: %#v", result)
} }
if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 { if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 {
t.Fatalf("invalid create/update unix milli time: %#v", result) t.Fatalf("invalid create/update unix milli time: %#v", result)
} }
if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 {
t.Fatalf("invalid create/update unix nano time: %#v", result) t.Fatalf("invalid create/update unix nano time: %#v", result)
} }
@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) {
var createWithDefaultTimeResult CustomizeFieldStruct var createWithDefaultTimeResult CustomizeFieldStruct
DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name)
if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
} }
if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult)
} }
if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
} }
} }

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) {
t.Errorf("should returns no error while enable global update, but got err %v", err) t.Errorf("should returns no error while enable global update, but got err %v", err)
} }
} }
func TestDeleteWithAssociations(t *testing.T) {
user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1})
if err := DB.Create(user).Error; err != nil {
t.Fatalf("failed to create user, got error %v", err)
}
if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil {
t.Fatalf("failed to delete user, got error %v", err)
}
for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} {
if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} {
if count := DB.Model(&user).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
}
func TestDeleteSliceWithAssociations(t *testing.T) {
users := []User{
*GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}),
*GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}),
*GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}),
*GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}),
}
if err := DB.Create(users).Error; err != nil {
t.Fatalf("failed to create user, got error %v", err)
}
if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil {
t.Fatalf("failed to delete user, got error %v", err)
}
for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} {
if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} {
if count := DB.Model(&users).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
}

View File

@ -163,6 +163,8 @@ func TestEmbeddedRelations(t *testing.T) {
DB.Migrator().DropTable(&AdvancedUser{}) DB.Migrator().DropTable(&AdvancedUser{})
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
t.Errorf("Failed to auto migrate advanced user, got error %v", err) if DB.Dialector.Name() != "sqlite" {
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
}
} }
} }

View File

@ -7,11 +7,13 @@ require (
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
github.com/stevefan1999-personal/gorm-driver-oracle latest github.com/stevefan1999-personal/gorm-driver-oracle latest
gorm.io/driver/mysql v1.0.0 gorm.io/driver/mysql v1.0.1
gorm.io/driver/postgres v1.0.0 gorm.io/driver/postgres v1.0.2
gorm.io/driver/sqlite v1.1.0 gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.1 gorm.io/driver/sqlserver v1.0.4
gorm.io/gorm v1.9.19 gorm.io/gorm v1.20.2
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../
replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2

View File

@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Pets", func(t *testing.T) { t.Run("Pets", func(t *testing.T) {
if len(user.Pets) != len(expect.Pets) { if len(user.Pets) != len(expect.Pets) {
t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets))
} }
sort.Slice(user.Pets, func(i, j int) bool { sort.Slice(user.Pets, func(i, j int) bool {
@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Toys", func(t *testing.T) { t.Run("Toys", func(t *testing.T) {
if len(user.Toys) != len(expect.Toys) { if len(user.Toys) != len(expect.Toys) {
t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys))
} }
sort.Slice(user.Toys, func(i, j int) bool { sort.Slice(user.Toys, func(i, j int) bool {
@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Team", func(t *testing.T) { t.Run("Team", func(t *testing.T) {
if len(user.Team) != len(expect.Team) { if len(user.Team) != len(expect.Team) {
t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team))
} }
sort.Slice(user.Team, func(i, j int) bool { sort.Slice(user.Team, func(i, j int) bool {
@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Languages", func(t *testing.T) { t.Run("Languages", func(t *testing.T) {
if len(user.Languages) != len(expect.Languages) { if len(user.Languages) != len(expect.Languages) {
t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages))
} }
sort.Slice(user.Languages, func(i, j int) bool { sort.Slice(user.Languages, func(i, j int) bool {
@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) {
t.Run("Friends", func(t *testing.T) { t.Run("Friends", func(t *testing.T) {
if len(user.Friends) != len(expect.Friends) { if len(user.Friends) != len(expect.Friends) {
t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends))
} }
sort.Slice(user.Friends, func(i, j int) bool { sort.Slice(user.Friends, func(i, j int) bool {

View File

@ -5,12 +5,14 @@ import (
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type Person struct { type Person struct {
ID int ID int
Name string Name string
Addresses []Address `gorm:"many2many:person_addresses;"` Addresses []Address `gorm:"many2many:person_addresses;"`
DeletedAt gorm.DeletedAt
} }
type Address struct { type Address struct {
@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) {
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
t.Fatalf("address should be deleted when clear with unscoped") t.Fatalf("address should be deleted when clear with unscoped")
} }
address2_1 := Address{Name: "address 2-1"}
address2_2 := Address{Name: "address 2-2"}
person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}}
DB.Create(&person2)
if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil {
t.Fatalf("failed to delete person, got error: %v", err)
}
if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 {
t.Errorf("person's addresses expects 2, got %v", count)
}
if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 {
t.Errorf("person's addresses expects 2, got %v", count)
}
} }

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
) )
type Blog struct { type Blog struct {
@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
t.Fatalf("EN Blog's tags should be cleared") t.Fatalf("EN Blog's tags should be cleared")
} }
} }
func TestCompositePrimaryKeysAssociations(t *testing.T) {
type Label struct {
BookID *uint `gorm:"primarykey"`
Name string `gorm:"primarykey"`
Value string
}
type Book struct {
ID int
Name string
Labels []Label
}
DB.Migrator().DropTable(&Label{}, &Book{})
if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil {
t.Fatalf("failed to migrate")
}
book := Book{
Name: "my book",
Labels: []Label{
{Name: "region", Value: "emea"},
},
}
DB.Create(&book)
var result Book
if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil {
t.Fatalf("failed to preload, got error %v", err)
}
AssertEqual(t, book, result)
}

View File

@ -48,10 +48,22 @@ func TestNamedArg(t *testing.T) {
t.Errorf("failed to update with named arg") t.Errorf("failed to update with named arg")
} }
namedUser.Name1 = "jinzhu-new"
namedUser.Name2 = "jinzhu-new2"
namedUser.Name3 = "jinzhu-new"
var result5 NamedUser var result5 NamedUser
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil {
t.Errorf("failed to update with named arg") t.Errorf("failed to update with named arg")
} }
AssertEqual(t, result4, namedUser) AssertEqual(t, result5, namedUser)
var result6 NamedUser
if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name
AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil {
t.Errorf("failed to update with named arg")
}
AssertEqual(t, result6, namedUser)
} }

View File

@ -1,6 +1,8 @@
package tests_test package tests_test
import ( import (
"encoding/json"
"regexp"
"sort" "sort"
"strconv" "strconv"
"testing" "testing"
@ -31,6 +33,20 @@ func TestPreloadWithAssociations(t *testing.T) {
var user2 User var user2 User
DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
var user3 = *GetUser("preload_with_associations_new", Config{
Account: true,
Pets: 2,
Toys: 3,
Company: true,
Manager: true,
Team: 4,
Languages: 3,
Friends: 1,
})
DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID)
CheckUser(t, user3, user)
} }
func TestNestedPreload(t *testing.T) { func TestNestedPreload(t *testing.T) {
@ -174,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) {
CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2])
} }
} }
func TestPreloadEmptyData(t *testing.T) {
var user = *GetUser("user_without_associations", Config{})
DB.Create(&user)
DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name)
if r, err := json.Marshal(&user); err != nil {
t.Errorf("failed to marshal users, got error %v", err)
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
t.Errorf("json marshal is not empty slice, got %v", string(r))
}
var results []User
DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name)
if r, err := json.Marshal(&results); err != nil {
t.Errorf("failed to marshal users, got error %v", err)
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
t.Errorf("json marshal is not empty slice, got %v", string(r))
}
}

View File

@ -0,0 +1,52 @@
package tests_test
import (
"context"
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
func TestPreparedStmt(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
txCtx := tx.WithContext(ctx)
user := *GetUser("prepared_stmt", Config{})
txCtx.Create(&user)
var result1 User
if err := txCtx.Find(&result1, user.ID).Error; err != nil {
t.Fatalf("no error should happen but got %v", err)
}
time.Sleep(time.Second)
var result2 User
if err := tx.Find(&result2, user.ID).Error; err != nil {
t.Fatalf("no error should happen but got %v", err)
}
user2 := *GetUser("prepared_stmt2", Config{})
if err := txCtx.Create(&user2).Error; err == nil {
t.Fatalf("should failed to create with timeout context")
}
if err := tx.Create(&user2).Error; err != nil {
t.Fatalf("failed to create, got error %v", err)
}
var result3 User
if err := tx.Find(&result3, user2.ID).Error; err != nil {
t.Fatalf("no error should happen but got %v", err)
}
}

View File

@ -1,11 +1,13 @@
package tests_test package tests_test
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -61,6 +63,54 @@ func TestFind(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} { for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name) dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := first[dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Age":
if _, ok := first[dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Birthday":
if _, ok := first[dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
})
}
}
})
t.Run("FirstMapWithTable", func(t *testing.T) {
var first = map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(first[dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
}) })
@ -86,13 +136,29 @@ func TestFind(t *testing.T) {
t.Run("FirstSliceOfMap", func(t *testing.T) { t.Run("FirstSliceOfMap", func(t *testing.T) {
var allMap = []map[string]interface{}{} var allMap = []map[string]interface{}{}
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query first: %v", err) t.Errorf("errors happened when query find: %v", err)
} else { } else {
for idx, user := range users { for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} { for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name) dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := allMap[idx][dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Age":
if _, ok := allMap[idx][dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Birthday":
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user)) reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
}) })
@ -101,6 +167,58 @@ func TestFind(t *testing.T) {
} }
} }
}) })
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
var allMap = []map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query find: %v", err)
} else {
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
}
})
}
}
})
var models []User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models))
} else {
for idx, user := range users {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models[idx], user)
})
}
}
var none []User
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
}
} }
func TestQueryWithAssociation(t *testing.T) { func TestQueryWithAssociation(t *testing.T) {
@ -228,6 +346,11 @@ func TestNot(t *testing.T) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
} }
result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) result = dryDB.Not("name = ?", "jinzhu").Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
@ -309,6 +432,33 @@ func TestPluck(t *testing.T) {
t.Errorf("Unexpected result on pluck id, got %+v", ids) t.Errorf("Unexpected result on pluck id, got %+v", ids)
} }
} }
var times []time.Time
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &times).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range times {
AssertEqual(t, tv, users[idx].CreatedAt)
}
var ptrtimes []*time.Time
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range ptrtimes {
AssertEqual(t, tv, users[idx].CreatedAt)
}
var nulltimes []sql.NullTime
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range nulltimes {
AssertEqual(t, tv.Time, users[idx].CreatedAt)
}
} }
func TestSelect(t *testing.T) { func TestSelect(t *testing.T) {
@ -508,6 +658,7 @@ func TestLimit(t *testing.T) {
{Name: "LimitUser3", Age: 20}, {Name: "LimitUser3", Age: 20},
{Name: "LimitUser4", Age: 10}, {Name: "LimitUser4", Age: 10},
{Name: "LimitUser5", Age: 20}, {Name: "LimitUser5", Age: 20},
{Name: "LimitUser6", Age: 20},
} }
DB.Create(&users) DB.Create(&users)
@ -516,7 +667,7 @@ func TestLimit(t *testing.T) {
DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
t.Errorf("Limit should works") t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3))
} }
} }
@ -531,6 +682,7 @@ func TestOffset(t *testing.T) {
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
t.Errorf("Offset should work") t.Errorf("Offset should work")
} }
DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
@ -698,3 +850,11 @@ func TestScanNullValue(t *testing.T) {
t.Fatalf("failed to query slice data with null age, got error %v", err) t.Fatalf("failed to query slice data with null age, got error %v", err)
} }
} }
func TestQueryWithTableAndConditions(t *testing.T) {
result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"})
if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
}

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -16,14 +17,25 @@ func TestScan(t *testing.T) {
DB.Save(&user1).Save(&user2).Save(&user3) DB.Save(&user1).Save(&user2).Save(&user3)
type result struct { type result struct {
ID uint
Name string Name string
Age int Age int
} }
var res result var res result
DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res)
if res.Name != user3.Name || res.Age != int(user3.Age) { if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
t.Errorf("Scan into struct should work") t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
}
DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res)
if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2)
}
DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res)
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
} }
var doubleAgeRes = &result{} var doubleAgeRes = &result{}
@ -39,11 +51,11 @@ func TestScan(t *testing.T) {
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results)
sort.Slice(results, func(i, j int) bool { sort.Slice(results, func(i, j int) bool {
return strings.Compare(results[i].Name, results[j].Name) < -1 return strings.Compare(results[i].Name, results[j].Name) <= -1
}) })
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
t.Errorf("Scan into struct map") t.Errorf("Scan into struct map, got %#v", results)
} }
} }
@ -72,7 +84,21 @@ func TestScanRows(t *testing.T) {
results = append(results, result) results = append(results, result)
} }
sort.Slice(results, func(i, j int) bool {
return strings.Compare(results[i].Name, results[j].Name) <= -1
})
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
t.Errorf("Should find expected results") t.Errorf("Should find expected results")
} }
var ages int
if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 {
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages)
}
var name string
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
}
} }

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"regexp"
"strings" "strings"
"testing" "testing"
@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) {
t.Errorf("expects: %v, got %v", expects, result) t.Errorf("expects: %v, got %v", expects, result)
} }
} }
func TestCombineStringConditions(t *testing.T) {
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
}

View File

@ -13,7 +13,7 @@ if [ -d tests ]
then then
cd tests cd tests
cp go.mod go.mod.bak cp go.mod go.mod.bak
sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod
cd .. cd ..
fi fi

View File

@ -22,7 +22,7 @@ var DB *gorm.DB
func init() { func init() {
var err error var err error
if DB, err = OpenTestConnection(); err != nil { if DB, err = OpenTestConnection(); err != nil {
log.Printf("failed to connect database, got error %v\n", err) log.Printf("failed to connect database, got error %v", err)
os.Exit(1) os.Exit(1)
} else { } else {
sqlDB, err := DB.DB() sqlDB, err := DB.DB()
@ -31,7 +31,7 @@ func init() {
} }
if err != nil { if err != nil {
log.Printf("failed to connect database, got error %v\n", err) log.Printf("failed to connect database, got error %v", err)
} }
RunMigrations() RunMigrations()

View File

@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) {
t.Fatalf("Should find saved record") t.Fatalf("Should find saved record")
} }
} }
func TestTransactionOnClosedConn(t *testing.T) {
DB, err := OpenTestConnection()
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
rawDB, _ := DB.DB()
rawDB.Close()
if err := DB.Transaction(func(tx *gorm.DB) error {
return nil
}); err == nil {
t.Errorf("should returns error when commit with closed conn, got error %v", err)
}
if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error {
return nil
}); err == nil {
t.Errorf("should returns error when commit with closed conn, got error %v", err)
}
}

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) {
var user2 User var user2 User
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
user.Company.Name += "new"
user.Manager.Name += "new"
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
} }

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) {
DB.Preload("Pets").Find(&user2, "id = ?", user.ID) DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
for _, pet := range user.Pets {
pet.Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Pets").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Pets").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
t.Run("Polymorphic", func(t *testing.T) { t.Run("Polymorphic", func(t *testing.T) {
var user = *GetUser("update-has-many", Config{}) var user = *GetUser("update-has-many", Config{})
@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) {
var user2 User var user2 User
DB.Preload("Toys").Find(&user2, "id = ?", user.ID) DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
for idx := range user.Toys {
user.Toys[idx].Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Toys").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Toys").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
}) })
} }

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) {
DB.Preload("Account").Find(&user2, "id = ?", user.ID) DB.Preload("Account").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
user.Account.Number += "new"
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Account").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Account").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
t.Run("Polymorphic", func(t *testing.T) { t.Run("Polymorphic", func(t *testing.T) {
var pet = Pet{Name: "create"} var pet = Pet{Name: "create"}
@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) {
var pet2 Pet var pet2 Pet
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
CheckPet(t, pet2, pet) CheckPet(t, pet2, pet)
pet.Toy.Name += "new"
if err := DB.Save(&pet).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var pet3 Pet
DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
CheckPet(t, pet2, pet3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var pet4 Pet
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
CheckPet(t, pet4, pet)
}) })
} }

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) {
var user2 User var user2 User
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
for idx := range user.Friends {
user.Friends[idx].Name += "new"
}
for idx := range user.Languages {
user.Languages[idx].Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
} }

View File

@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) {
var result2 Language var result2 Language
DB.First(&result2, "code = ?", "save") DB.First(&result2, "code = ?", "save")
AssertEqual(t, result2, lang) AssertEqual(t, result2, lang)
DB.Table("langs").Migrator().DropTable(&Language{})
DB.Table("langs").AutoMigrate(&Language{})
if err := DB.Table("langs").Save(&lang).Error; err != nil {
t.Errorf("no error should happen when creating data, but got %v", err)
}
var result3 Language
if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name {
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3)
}
lang.Name += "name2"
if err := DB.Table("langs").Save(&lang).Error; err != nil {
t.Errorf("no error should happen when creating data, but got %v", err)
}
var result4 Language
if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name {
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
}
} }

View File

@ -30,7 +30,7 @@ func FileWithLineNum() string {
} }
func IsValidDBNameChar(c rune) bool { func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
} }
func CheckTruth(val interface{}) bool { func CheckTruth(val interface{}) bool {
@ -83,3 +83,31 @@ func AssertEqual(src, dst interface{}) bool {
} }
return true return true
} }
func ToString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case int:
return strconv.FormatInt(int64(v), 10)
case int8:
return strconv.FormatInt(int64(v), 10)
case int16:
return strconv.FormatInt(int64(v), 10)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(v, 10)
case uint:
return strconv.FormatUint(uint64(v), 10)
case uint8:
return strconv.FormatUint(uint64(v), 10)
case uint16:
return strconv.FormatUint(uint64(v), 10)
case uint32:
return strconv.FormatUint(uint64(v), 10)
case uint64:
return strconv.FormatUint(v, 10)
}
return ""
}