Merge branch 'master' into patch-oracle
This commit is contained in:
commit
b2fdc25f73
5
.github/labels.json
vendored
5
.github/labels.json
vendored
@ -10,6 +10,11 @@
|
||||
"colour": "#EDEDED",
|
||||
"description": "general questions"
|
||||
},
|
||||
"invalid_question": {
|
||||
"name": "type:invalid question",
|
||||
"colour": "#CF2E1F",
|
||||
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
|
||||
},
|
||||
"with_playground": {
|
||||
"name": "type:with reproduction steps",
|
||||
"colour": "#00ff00",
|
||||
|
22
.github/workflows/invalid_question.yml
vendored
Normal file
22
.github/workflows/invalid_question.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: "Close invalid questions issues"
|
||||
on:
|
||||
schedule:
|
||||
- cron: "*/10 * * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v3.0.7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
stale-issue-label: "status:stale"
|
||||
days-before-stale: 0
|
||||
days-before-close: 2
|
||||
remove-stale-when-updated: true
|
||||
only-labels: "type:invalid question"
|
||||
|
2
.github/workflows/missing_playground.yml
vendored
2
.github/workflows/missing_playground.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
uses: actions/stale@v3.0.7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs."
|
||||
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
stale-issue-label: "status:stale"
|
||||
days-before-stale: 0
|
||||
days-before-close: 2
|
||||
|
16
callbacks.go
16
callbacks.go
@ -8,7 +8,6 @@ import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
@ -75,17 +74,22 @@ func (cs *callbacks) Raw() *processor {
|
||||
func (p *processor) Execute(db *DB) {
|
||||
curTime := time.Now()
|
||||
stmt := db.Statement
|
||||
db.RowsAffected = 0
|
||||
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
} else if stmt.Dest == nil {
|
||||
stmt.Dest = stmt.Model
|
||||
}
|
||||
|
||||
if stmt.Model != nil {
|
||||
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
|
||||
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
|
||||
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Dest != nil {
|
||||
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
|
||||
@ -154,7 +158,7 @@ func (p *processor) compile() (err error) {
|
||||
p.callbacks = callbacks
|
||||
|
||||
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||
logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
||||
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -177,7 +181,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
|
||||
}
|
||||
|
||||
func (c *callback) Remove(name string) error {
|
||||
logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.remove = true
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
@ -185,7 +189,7 @@ func (c *callback) Remove(name string) error {
|
||||
}
|
||||
|
||||
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||
logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.handler = fn
|
||||
c.replace = true
|
||||
@ -215,7 +219,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
for _, c := range cs {
|
||||
// show warning message the callback name already exists
|
||||
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
||||
logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
||||
}
|
||||
names = append(names, c.name)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
@ -65,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
DoNothing: true,
|
||||
}).Create(elems.Interface()).Error) == nil {
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
@ -80,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
DoNothing: true,
|
||||
}).Create(rv.Interface()).Error) == nil {
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
|
||||
setupReferences(db.Statement.ReflectValue, rv)
|
||||
}
|
||||
}
|
||||
@ -144,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
|
||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
||||
}).Create(elems.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||
).Create(elems.Interface()).Error)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||
@ -167,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
|
||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
||||
}).Create(f.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||
).Create(f.Interface()).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -229,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
|
||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
||||
}).Create(elems.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||
).Create(elems.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
@ -297,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
|
||||
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
appendToJoins(objs[i], elems.Index(i))
|
||||
@ -310,3 +304,32 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
|
||||
if stmt.DB.FullSaveAssociations {
|
||||
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
|
||||
for _, dbName := range s.DBNames {
|
||||
if !s.LookUpField(dbName).PrimaryKey {
|
||||
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(defaultUpdatingColumns) > 0 {
|
||||
var columns []clause.Column
|
||||
if s.PrioritizedPrimaryField != nil {
|
||||
columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
|
||||
} else {
|
||||
for _, dbName := range s.PrimaryFieldDBNames {
|
||||
columns = append(columns, clause.Column{Name: dbName})
|
||||
}
|
||||
}
|
||||
|
||||
return clause.OnConflict{
|
||||
Columns: columns,
|
||||
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
|
||||
}
|
||||
}
|
||||
|
||||
return clause.OnConflict{DoNothing: true}
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||
deleteCallback := db.Callback().Delete()
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||
deleteCallback.Register("gorm:delete", Delete)
|
||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@ -87,9 +88,12 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
}
|
||||
case reflect.Struct:
|
||||
if insertID > 0 {
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
@ -252,8 +256,18 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
stmt.SQL.Grow(stmt.ReflectValue.Len() * 15)
|
||||
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
if stmt.ReflectValue.Len() == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
if !rv.IsValid() {
|
||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||
return
|
||||
}
|
||||
|
||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
|
@ -2,6 +2,7 @@ package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -21,6 +22,85 @@ func BeforeDelete(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteBeforeAssociations(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||
|
||||
if restricted {
|
||||
for column, v := range selectColumns {
|
||||
if v {
|
||||
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
|
||||
switch rel.Type {
|
||||
case schema.HasOne, schema.HasMany:
|
||||
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
|
||||
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := db.Session(&gorm.Session{}).Model(modelValue)
|
||||
withoutConditions := false
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
var selects []string
|
||||
for _, s := range db.Statement.Selects {
|
||||
if s == clause.Associations {
|
||||
selects = append(selects, s)
|
||||
} else if strings.HasPrefix(s, column+".") {
|
||||
selects = append(selects, strings.TrimPrefix(s, column+"."))
|
||||
}
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
}
|
||||
}
|
||||
|
||||
for _, cond := range queryConds {
|
||||
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
|
||||
withoutConditions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !withoutConditions {
|
||||
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
queryConds []clause.Expression
|
||||
foreignFields []*schema.Field
|
||||
relForeignKeys []string
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
table = rel.JoinTable.Table
|
||||
tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table)
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
queryConds = append(queryConds, clause.Eq{
|
||||
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
|
||||
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
||||
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
||||
|
||||
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Delete(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
|
@ -46,6 +46,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
)
|
||||
|
||||
if len(mapValues) == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
for idx, mapValue := range mapValues {
|
||||
for k, v := range mapValue {
|
||||
if stmt.Schema != nil {
|
||||
|
@ -107,6 +107,26 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||
|
||||
fieldValues := make([]interface{}, len(relForeignFields))
|
||||
|
||||
// clean up old values before preloading
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
|
||||
default:
|
||||
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
|
||||
default:
|
||||
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < reflectResults.Len(); i++ {
|
||||
elem := reflectResults.Index(i)
|
||||
for idx, field := range relForeignFields {
|
||||
|
@ -11,11 +11,13 @@ func RowQuery(db *gorm.DB) {
|
||||
}
|
||||
|
||||
if !db.DryRun {
|
||||
if _, ok := db.Get("rows"); ok {
|
||||
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
}
|
||||
|
||||
db.RowsAffected = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
|
||||
// Order specify order when retrieve records from database
|
||||
// db.Order("name DESC")
|
||||
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
|
@ -3,6 +3,7 @@ package clause
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
@ -30,19 +31,23 @@ func (expr Expr) Build(builder Builder) {
|
||||
)
|
||||
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '?' {
|
||||
if v == '?' && len(expr.Vars) > idx {
|
||||
if afterParenthesis {
|
||||
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
} else {
|
||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() == 0 {
|
||||
builder.AddVar(builder, nil)
|
||||
} else {
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
}
|
||||
default:
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
@ -85,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) {
|
||||
for k, v := range value {
|
||||
namedMap[k] = v
|
||||
}
|
||||
default:
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
modelType := reflectValue.Type()
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,7 +110,7 @@ func (expr NamedExpr) Build(builder Builder) {
|
||||
if v == '@' && !inName {
|
||||
inName = true
|
||||
name = []byte{}
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' {
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
|
||||
if inName {
|
||||
if nv, ok := namedMap[string(name)]; ok {
|
||||
builder.AddVar(builder, nv)
|
||||
@ -106,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) {
|
||||
}
|
||||
|
||||
builder.WriteByte(v)
|
||||
} else if v == '?' {
|
||||
} else if v == '?' && len(expr.Vars) > idx {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
idx++
|
||||
} else if inName {
|
||||
|
@ -37,6 +37,11 @@ func TestExpr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNamedExpr(t *testing.T) {
|
||||
type NamedArgument struct {
|
||||
Name1 string
|
||||
Name2 string
|
||||
}
|
||||
|
||||
results := []struct {
|
||||
SQL string
|
||||
Result string
|
||||
@ -66,6 +71,15 @@ func TestNamedExpr(t *testing.T) {
|
||||
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||
}, {
|
||||
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
|
||||
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}},
|
||||
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||
}, {
|
||||
SQL: "create table ? (? ?, ? ?)",
|
||||
Vars: []interface{}{},
|
||||
Result: "create table ? (? ?, ? ?)",
|
||||
}}
|
||||
|
||||
for idx, result := range results {
|
||||
|
@ -33,10 +33,8 @@ func (limit Limit) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
|
||||
if v, ok := clause.Expression.(Limit); ok {
|
||||
if limit.Limit == 0 && v.Limit > 0 {
|
||||
if limit.Limit == 0 && v.Limit != 0 {
|
||||
limit.Limit = v.Limit
|
||||
} else if limit.Limit < 0 {
|
||||
limit.Limit = 0
|
||||
}
|
||||
|
||||
if limit.Offset == 0 && v.Offset > 0 {
|
||||
|
@ -1,5 +1,9 @@
|
||||
package clause
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Where where clause
|
||||
type Where struct {
|
||||
Exprs []Expression
|
||||
@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) {
|
||||
}
|
||||
}
|
||||
|
||||
wrapInParentheses := false
|
||||
for idx, expr := range where.Exprs {
|
||||
if idx > 0 {
|
||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||
@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(where.Exprs) > 1 {
|
||||
switch v := expr.(type) {
|
||||
case OrConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
case AndConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
case Expr:
|
||||
sql := strings.ToLower(v.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteString(`(`)
|
||||
expr.Build(builder)
|
||||
builder.WriteString(`)`)
|
||||
wrapInParentheses = false
|
||||
} else {
|
||||
expr.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) {
|
||||
func And(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
} else if len(exprs) == 1 {
|
||||
return exprs[0]
|
||||
}
|
||||
return AndConditions{Exprs: exprs}
|
||||
}
|
||||
@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) {
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
|
@ -27,4 +27,8 @@ var (
|
||||
ErrRegistered = errors.New("registered")
|
||||
// ErrInvalidField invalid field
|
||||
ErrInvalidField = errors.New("invalid field")
|
||||
// ErrEmptySlice empty slice found
|
||||
ErrEmptySlice = errors.New("empty slice found")
|
||||
// ErrDryRunModeUnsupported dry run mode unsupported
|
||||
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
|
||||
)
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
@ -51,7 +52,8 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||
if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) {
|
||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||
if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
|
||||
return tx.Create(value)
|
||||
}
|
||||
}
|
||||
@ -331,22 +333,43 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
||||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
tx := db.getInstance()
|
||||
tx := db.getInstance().InstanceSet("rows", false)
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
return tx.Statement.Dest.(*sql.Row)
|
||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||
if !ok && tx.DryRun {
|
||||
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
tx := db.Set("rows", true)
|
||||
tx := db.getInstance().InstanceSet("rows", true)
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
return tx.Statement.Dest.(*sql.Rows), tx.Error
|
||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||
if !ok && tx.DryRun && tx.Error == nil {
|
||||
tx.Error = ErrDryRunModeUnsupported
|
||||
}
|
||||
return rows, tx.Error
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
currentLogger, newLogger := db.Logger, logger.Recorder.New()
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
tx.Logger = newLogger
|
||||
if rows, err := tx.Rows(); err != nil {
|
||||
tx.AddError(err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
tx.ScanRows(rows, dest)
|
||||
}
|
||||
}
|
||||
|
||||
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
|
||||
return newLogger.SQL, tx.RowsAffected
|
||||
}, tx.Error)
|
||||
tx.Logger = currentLogger
|
||||
return
|
||||
}
|
||||
|
||||
@ -377,9 +400,14 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
tx := db.getInstance()
|
||||
tx.Error = tx.Statement.Parse(dest)
|
||||
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
|
||||
tx.AddError(err)
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest))
|
||||
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
||||
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
|
||||
}
|
||||
Scan(rows, tx, true)
|
||||
return tx.Error
|
||||
}
|
||||
|
13
gorm.go
13
gorm.go
@ -20,6 +20,8 @@ type Config struct {
|
||||
SkipDefaultTransaction bool
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
FullSaveAssociations bool
|
||||
// Logger
|
||||
Logger logger.Interface
|
||||
// NowFunc the function to be used when creating a new timestamp
|
||||
@ -64,6 +66,7 @@ type Session struct {
|
||||
WithConditions bool
|
||||
SkipDefaultTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||
txConfig.AllowGlobalUpdate = true
|
||||
}
|
||||
|
||||
if config.FullSaveAssociations {
|
||||
txConfig.FullSaveAssociations = true
|
||||
}
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
@ -169,12 +176,15 @@ func (db *DB) Session(config *Session) *DB {
|
||||
|
||||
if config.PrepareStmt {
|
||||
if v, ok := db.cacheStore.Load("preparedStmt"); ok {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
preparedStmt := v.(*PreparedStmtDB)
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
}
|
||||
|
||||
@ -316,6 +326,9 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
|
||||
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||
f.DataType = ref.ForeignKey.DataType
|
||||
f.GORMDataType = ref.ForeignKey.GORMDataType
|
||||
if f.Size == 0 {
|
||||
f.Size = ref.ForeignKey.Size
|
||||
}
|
||||
ref.ForeignKey = f
|
||||
} else {
|
||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||
|
@ -2,6 +2,7 @@ package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
@ -19,6 +20,7 @@ const (
|
||||
Magenta = "\033[35m"
|
||||
Cyan = "\033[36m"
|
||||
White = "\033[37m"
|
||||
BlueBold = "\033[34;1m"
|
||||
MagentaBold = "\033[35;1m"
|
||||
RedBold = "\033[31;1m"
|
||||
YellowBold = "\033[33;1m"
|
||||
@ -54,29 +56,33 @@ type Interface interface {
|
||||
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
|
||||
}
|
||||
|
||||
var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||
var (
|
||||
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||
SlowThreshold: 100 * time.Millisecond,
|
||||
LogLevel: Warn,
|
||||
Colorful: true,
|
||||
})
|
||||
Recorder = traceRecorder{Interface: Default}
|
||||
)
|
||||
|
||||
func New(writer Writer, config Config) Interface {
|
||||
var (
|
||||
infoStr = "%s\n[info] "
|
||||
warnStr = "%s\n[warn] "
|
||||
errStr = "%s\n[error] "
|
||||
traceStr = "%s\n[%v] [rows:%d] %s"
|
||||
traceWarnStr = "%s\n[%v] [rows:%d] %s"
|
||||
traceErrStr = "%s %s\n[%v] [rows:%d] %s"
|
||||
traceStr = "%s\n[%.3fms] [rows:%v] %s"
|
||||
traceWarnStr = "%s\n[%.3fms] [rows:%v] %s"
|
||||
traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
|
||||
)
|
||||
|
||||
if config.Colorful {
|
||||
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
||||
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
|
||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
|
||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||
}
|
||||
|
||||
return &logger{
|
||||
@ -133,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
|
||||
switch {
|
||||
case err != nil && l.LogLevel >= Error:
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
case l.LogLevel >= Info:
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type traceRecorder struct {
|
||||
Interface
|
||||
BeginAt time.Time
|
||||
SQL string
|
||||
RowsAffected int64
|
||||
Err error
|
||||
}
|
||||
|
||||
func (l traceRecorder) New() *traceRecorder {
|
||||
return &traceRecorder{Interface: l.Interface}
|
||||
}
|
||||
|
||||
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
l.BeginAt = begin
|
||||
l.SQL, l.RowsAffected = fc()
|
||||
l.Err = err
|
||||
}
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func isPrintable(s []byte) bool {
|
||||
@ -24,19 +26,38 @@ var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
|
||||
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||
var convertParams func(interface{}, int)
|
||||
var vars = make([]interface{}, len(avars))
|
||||
copy(vars, avars)
|
||||
var vars = make([]string, len(avars))
|
||||
|
||||
convertParams = func(v interface{}, idx int) {
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
vars[idx] = fmt.Sprint(v)
|
||||
vars[idx] = strconv.FormatBool(v)
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
|
||||
}
|
||||
case *time.Time:
|
||||
if v != nil {
|
||||
if v.IsZero() {
|
||||
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
|
||||
}
|
||||
} else {
|
||||
vars[idx] = "NULL"
|
||||
}
|
||||
case fmt.Stringer:
|
||||
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||
case driver.Valuer:
|
||||
reflectValue := reflect.ValueOf(v)
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
r, _ := v.Value()
|
||||
convertParams(r, idx)
|
||||
} else {
|
||||
vars[idx] = "NULL"
|
||||
}
|
||||
case []byte:
|
||||
if isPrintable(v) {
|
||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||
@ -44,7 +65,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
vars[idx] = escaper + "<binary>" + escaper
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
vars[idx] = fmt.Sprintf("%d", v)
|
||||
vars[idx] = utils.ToString(v)
|
||||
case float64, float32:
|
||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||
case string:
|
||||
@ -70,18 +91,30 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
}
|
||||
}
|
||||
|
||||
for idx, v := range vars {
|
||||
for idx, v := range avars {
|
||||
convertParams(v, idx)
|
||||
}
|
||||
|
||||
if numericPlaceholder == nil {
|
||||
for _, v := range vars {
|
||||
sql = strings.Replace(sql, "?", v.(string), 1)
|
||||
var idx int
|
||||
var newSQL strings.Builder
|
||||
|
||||
for _, v := range []byte(sql) {
|
||||
if v == '?' {
|
||||
if len(vars) > idx {
|
||||
newSQL.WriteString(vars[idx])
|
||||
idx++
|
||||
continue
|
||||
}
|
||||
}
|
||||
newSQL.WriteByte(v)
|
||||
}
|
||||
|
||||
sql = newSQL.String()
|
||||
} else {
|
||||
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
||||
for idx, v := range vars {
|
||||
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
|
||||
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,13 +1,39 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type JSON json.RawMessage
|
||||
|
||||
func (j JSON) Value() (driver.Value, error) {
|
||||
if len(j) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return json.RawMessage(j).MarshalJSON()
|
||||
}
|
||||
|
||||
type ExampleStruct struct {
|
||||
Name string
|
||||
Val string
|
||||
}
|
||||
|
||||
func (s ExampleStruct) Value() (driver.Value, error) {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
func format(v []byte, escaper string) string {
|
||||
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||
}
|
||||
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
type role string
|
||||
type password []byte
|
||||
@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) {
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password([]byte("pass"))
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
es = ExampleStruct{Name: "test", Val: "test"}
|
||||
)
|
||||
|
||||
results := []struct {
|
||||
@ -29,6 +59,12 @@ func TestExplainSQL(t *testing.T) {
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||
@ -47,6 +83,18 @@ func TestExplainSQL(t *testing.T) {
|
||||
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
}
|
||||
|
||||
for idx, r := range results {
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
|
||||
// Migrator returns migrator
|
||||
func (db *DB) Migrator() Migrator {
|
||||
return db.Dialector.Migrator(db)
|
||||
return db.Dialector.Migrator(db.Session(&Session{WithConditions: true}))
|
||||
}
|
||||
|
||||
// AutoMigrate run auto migration for given models
|
||||
|
@ -71,7 +71,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
|
||||
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
|
||||
} else {
|
||||
} else if field.DefaultValue != "(-)" {
|
||||
expr.SQL += " DEFAULT " + field.DefaultValue
|
||||
}
|
||||
}
|
||||
@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
if !tx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
@ -297,10 +306,12 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
@ -354,9 +365,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
|
||||
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1)
|
||||
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
||||
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
|
||||
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
@ -386,8 +397,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
|
||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
columnTypes, err = rows.ColumnTypes()
|
||||
}
|
||||
return err
|
||||
@ -584,6 +596,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
||||
var (
|
||||
modelNames, orderedModelNames []string
|
||||
orderedModelNamesMap = map[string]bool{}
|
||||
parsedSchemas = map[*schema.Schema]bool{}
|
||||
valuesMap = map[string]Dependency{}
|
||||
insertIntoOrderedList func(name string)
|
||||
parseDependence func(value interface{}, addToList bool)
|
||||
@ -593,23 +606,35 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
||||
dep := Dependency{
|
||||
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
||||
}
|
||||
beDependedOn := map[*schema.Schema]bool{}
|
||||
if err := dep.Parse(value); err != nil {
|
||||
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
||||
}
|
||||
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
||||
return
|
||||
}
|
||||
parsedSchemas[dep.Statement.Schema] = true
|
||||
|
||||
for _, rel := range dep.Schema.Relationships.Relations {
|
||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||
}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
if rel.Schema != rel.FieldSchema {
|
||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||
beDependedOn[rel.FieldSchema] = true
|
||||
}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
// append join value
|
||||
defer func(joinValue interface{}) {
|
||||
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||
if !beDependedOn[rel.FieldSchema] {
|
||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||
} else {
|
||||
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
parseDependence(fieldValue, autoAdd)
|
||||
}
|
||||
parseDependence(joinValue, autoAdd)
|
||||
}(reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Unlock()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok {
|
||||
db.Mux.RUnlock()
|
||||
@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
|
||||
stmt, err := db.ConnPool.PrepareContext(ctx, query)
|
||||
if err == nil {
|
||||
db.Stmts[query] = stmt
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := db.prepare(query)
|
||||
stmt, err := db.prepare(ctx, query)
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := db.prepare(query)
|
||||
stmt, err := db.prepare(ctx, query)
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := db.prepare(query)
|
||||
stmt, err := db.prepare(ctx, query)
|
||||
if err == nil {
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
}
|
||||
@ -99,10 +99,24 @@ type PreparedStmtTX struct {
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Rollback() error {
|
||||
if tx.Tx != nil {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
|
||||
if err == nil {
|
||||
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
stmt.Close()
|
||||
@ -114,7 +128,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
@ -128,9 +142,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
|
||||
if err == nil {
|
||||
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
|
||||
return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
86
scan.go
86
scan.go
@ -2,22 +2,63 @@ package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
} else if len(columnTypes) > 0 {
|
||||
for idx, columnType := range columnTypes {
|
||||
if columnType.ScanType() != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||
} else {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
|
||||
for idx, column := range columns {
|
||||
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
|
||||
mapValue[column] = reflectValue.Interface()
|
||||
if valuer, ok := mapValue[column].(driver.Valuer); ok {
|
||||
mapValue[column], _ = valuer.Value()
|
||||
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
|
||||
mapValue[column] = string(b)
|
||||
}
|
||||
} else {
|
||||
mapValue[column] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
columns, _ := rows.Columns()
|
||||
values := make([]interface{}, len(columns))
|
||||
db.RowsAffected = 0
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
if initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
@ -28,41 +69,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
mapValue = *v
|
||||
}
|
||||
}
|
||||
|
||||
for idx, column := range columns {
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if v == nil {
|
||||
mapValue[column] = nil
|
||||
} else {
|
||||
mapValue[column] = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
}
|
||||
case *[]map[string]interface{}:
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
for initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue := map[string]interface{}{}
|
||||
for idx, column := range columns {
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if v == nil {
|
||||
mapValue[column] = nil
|
||||
} else {
|
||||
mapValue[column] = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
*dest = append(*dest, mapValue)
|
||||
}
|
||||
case *int, *int64, *uint, *uint64, *float32, *float64:
|
||||
case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time:
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
@ -114,7 +136,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
}
|
||||
|
||||
// pluck values into slice of data
|
||||
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
|
||||
isPluck := false
|
||||
if len(fields) == 1 {
|
||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok {
|
||||
isPluck = true
|
||||
} else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) {
|
||||
isPluck = true
|
||||
}
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
|
@ -18,6 +18,8 @@ type DataType string
|
||||
|
||||
type TimeType int64
|
||||
|
||||
var TimeReflectType = reflect.TypeOf(time.Time{})
|
||||
|
||||
const (
|
||||
UnixSecond TimeType = 1
|
||||
UnixMillisecond TimeType = 2
|
||||
@ -70,6 +72,8 @@ type Field struct {
|
||||
}
|
||||
|
||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
var err error
|
||||
|
||||
field := &Field{
|
||||
Name: fieldStruct.Name,
|
||||
BindNames: []string{fieldStruct.Name},
|
||||
@ -100,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
var getRealFieldValue func(reflect.Value)
|
||||
getRealFieldValue = func(v reflect.Value) {
|
||||
rv := reflect.Indirect(v)
|
||||
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
||||
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
|
||||
for i := 0; i < rv.Type().NumField(); i++ {
|
||||
newFieldType := rv.Type().Field(i).Type
|
||||
for newFieldType.Kind() == reflect.Ptr {
|
||||
@ -151,7 +155,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
|
||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||
var err error
|
||||
if field.Size, err = strconv.Atoi(num); err != nil {
|
||||
field.Size = -1
|
||||
}
|
||||
@ -177,33 +180,42 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
field.Comment = val
|
||||
}
|
||||
|
||||
// default value is function or null or blank (primary keys)
|
||||
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
|
||||
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
|
||||
switch reflect.Indirect(fieldValue).Kind() {
|
||||
case reflect.Bool:
|
||||
field.DataType = Bool
|
||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
||||
field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.DataType = Int
|
||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
||||
field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.DataType = Uint
|
||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
||||
field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.DataType = Float
|
||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
||||
field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
field.DataType = String
|
||||
isFunc := strings.Contains(field.DefaultValue, "(") &&
|
||||
strings.Contains(field.DefaultValue, ")")
|
||||
|
||||
if field.HasDefaultValue && !isFunc {
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
|
||||
field.DefaultValue = strings.Trim(field.DefaultValue, "\"")
|
||||
field.DefaultValueInterface = field.DefaultValue
|
||||
@ -211,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
case reflect.Struct:
|
||||
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
||||
field.DataType = Time
|
||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
||||
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
|
||||
field.DataType = Time
|
||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
||||
field.DataType = Time
|
||||
@ -316,7 +328,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
|
||||
cacheStore := &sync.Map{}
|
||||
cacheStore.Store(embeddedCacheKey, true)
|
||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil {
|
||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
|
||||
@ -331,10 +343,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
|
||||
}
|
||||
|
||||
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
|
||||
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
|
||||
ef.DBName = prefix + ef.DBName
|
||||
}
|
||||
|
||||
if ef.PrimaryKey {
|
||||
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
|
||||
ef.PrimaryKey = true
|
||||
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
|
||||
@ -350,6 +363,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
ef.HasDefaultValue = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range field.TagSettings {
|
||||
ef.TagSettings[k] = v
|
||||
@ -612,6 +626,14 @@ func (field *Field) setupValuerAndSetter() {
|
||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
||||
case []byte:
|
||||
return field.Set(value, string(data))
|
||||
case time.Time:
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
|
||||
} else {
|
||||
field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
|
||||
field.ReflectValueOf(value).SetUint(i)
|
||||
@ -671,7 +693,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
case []byte:
|
||||
field.ReflectValueOf(value).SetString(string(data))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
field.ReflectValueOf(value).SetString(fmt.Sprint(data))
|
||||
field.ReflectValueOf(value).SetString(utils.ToString(data))
|
||||
case float64, float32:
|
||||
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
||||
default:
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
type Namer interface {
|
||||
TableName(table string) string
|
||||
ColumnName(table, column string) string
|
||||
JoinTableName(table string) string
|
||||
JoinTableName(joinTable string) string
|
||||
RelationshipFKName(Relationship) string
|
||||
CheckerName(table, column string) string
|
||||
IndexName(table, column string) string
|
||||
@ -41,6 +41,10 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
|
||||
|
||||
// JoinTableName convert string to join table name
|
||||
func (ns NamingStrategy) JoinTableName(str string) string {
|
||||
if strings.ToLower(str) == str {
|
||||
return ns.TablePrefix + str
|
||||
}
|
||||
|
||||
if ns.SingularTable {
|
||||
return ns.TablePrefix + toDBName(str)
|
||||
}
|
||||
@ -49,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string {
|
||||
|
||||
// RelationshipFKName generate fk name for relation
|
||||
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
|
||||
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name))
|
||||
return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1)
|
||||
}
|
||||
|
||||
// CheckerName generate checker name
|
||||
func (ns NamingStrategy) CheckerName(table, column string) string {
|
||||
return fmt.Sprintf("chk_%s_%s", table, column)
|
||||
return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1)
|
||||
}
|
||||
|
||||
// IndexName generate index name
|
||||
func (ns NamingStrategy) IndexName(table, column string) string {
|
||||
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
|
||||
idxName = strings.Replace(idxName, ".", "_", -1)
|
||||
|
||||
if utf8.RuneCountInString(idxName) > 64 {
|
||||
h := sha1.New()
|
||||
|
@ -32,3 +32,35 @@ func TestToDBName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamingStrategy(t *testing.T) {
|
||||
var ns = NamingStrategy{
|
||||
TablePrefix: "public.",
|
||||
SingularTable: true,
|
||||
}
|
||||
idxName := ns.IndexName("public.table", "name")
|
||||
|
||||
if idxName != "idx_public_table_name" {
|
||||
t.Errorf("invalid index name generated, got %v", idxName)
|
||||
}
|
||||
|
||||
chkName := ns.CheckerName("public.table", "name")
|
||||
if chkName != "chk_public_table_name" {
|
||||
t.Errorf("invalid checker name generated, got %v", chkName)
|
||||
}
|
||||
|
||||
joinTable := ns.JoinTableName("user_languages")
|
||||
if joinTable != "public.user_languages" {
|
||||
t.Errorf("invalid join table generated, got %v", joinTable)
|
||||
}
|
||||
|
||||
joinTable2 := ns.JoinTableName("UserLanguage")
|
||||
if joinTable2 != "public.user_language" {
|
||||
t.Errorf("invalid join table generated, got %v", joinTable2)
|
||||
}
|
||||
|
||||
tableName := ns.TableName("Company")
|
||||
if tableName != "public.company" {
|
||||
t.Errorf("invalid table name generated, got %v", tableName)
|
||||
}
|
||||
}
|
||||
|
@ -165,6 +165,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
||||
// use same data type for foreign keys
|
||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
|
||||
if relation.Polymorphic.PolymorphicID.Size == 0 {
|
||||
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
|
||||
}
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: primaryKeyField,
|
||||
@ -254,12 +257,18 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
})
|
||||
}
|
||||
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: schema.Name + field.Name,
|
||||
Type: schema.ModelType,
|
||||
Tag: `gorm:"-"`,
|
||||
})
|
||||
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
|
||||
relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields))
|
||||
relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
|
||||
|
||||
relName := relation.Schema.Name
|
||||
relRefName := relation.FieldSchema.Name
|
||||
@ -290,11 +299,15 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
// build references
|
||||
for idx, f := range relation.JoinTable.Fields {
|
||||
for _, f := range relation.JoinTable.Fields {
|
||||
if f.Creatable || f.Readable || f.Updatable {
|
||||
// use same data type for foreign keys
|
||||
f.DataType = fieldsMap[f.Name].DataType
|
||||
f.GORMDataType = fieldsMap[f.Name].GORMDataType
|
||||
relation.JoinTable.PrimaryFields[idx] = f
|
||||
if f.Size == 0 {
|
||||
f.Size = fieldsMap[f.Name].Size
|
||||
}
|
||||
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
|
||||
ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
|
||||
|
||||
if ownPriamryField {
|
||||
@ -322,6 +335,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type guessLevel int
|
||||
|
||||
@ -428,6 +442,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
|
||||
// use same data type for foreign keys
|
||||
foreignField.DataType = primaryFields[idx].DataType
|
||||
foreignField.GORMDataType = primaryFields[idx].GORMDataType
|
||||
if foreignField.Size == 0 {
|
||||
foreignField.Size = primaryFields[idx].Size
|
||||
}
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: primaryFields[idx],
|
||||
|
@ -206,16 +206,39 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
|
||||
Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
|
||||
Refer uint
|
||||
}
|
||||
|
||||
checkStructRelation(t, &User{}, Relation{
|
||||
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
|
||||
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
|
||||
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
|
||||
References: []Reference{
|
||||
{"ID", "User", "UserReferID", "user_profiles", "", true},
|
||||
{"ID", "Profile", "ProfileRefer", "user_profiles", "", false},
|
||||
{"ID", "User", "UserReferID", "user_profile", "", true},
|
||||
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildReadonlyMany2ManyRelation(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserRefer uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
|
||||
Refer uint
|
||||
}
|
||||
|
||||
checkStructRelation(t, &User{}, Relation{
|
||||
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
|
||||
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
|
||||
References: []Reference{
|
||||
{"ID", "User", "UserReferID", "user_profile", "", true},
|
||||
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
|
||||
},
|
||||
})
|
||||
}
|
||||
@ -267,3 +290,34 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) {
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestMultipleMany2Many(t *testing.T) {
|
||||
type Thing struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
type Person struct {
|
||||
ID int
|
||||
Likes []Thing `gorm:"many2many:likes"`
|
||||
Dislikes []Thing `gorm:"many2many:dislikes"`
|
||||
}
|
||||
|
||||
checkStructRelation(t, &Person{},
|
||||
Relation{
|
||||
Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
|
||||
JoinTable: JoinTable{Name: "likes", Table: "likes"},
|
||||
References: []Reference{
|
||||
{"ID", "Person", "PersonID", "likes", "", true},
|
||||
{"ID", "Thing", "ThingID", "likes", "", false},
|
||||
},
|
||||
},
|
||||
Relation{
|
||||
Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
|
||||
JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"},
|
||||
References: []Reference{
|
||||
{"ID", "Person", "PersonID", "dislikes", "", true},
|
||||
{"ID", "Thing", "ThingID", "dislikes", "", false},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||
tableName = tabler.TableName()
|
||||
}
|
||||
if en, ok := namer.(embeddedNamer); ok {
|
||||
tableName = en.Table
|
||||
}
|
||||
|
||||
schema := &Schema{
|
||||
Name: modelType.Name(),
|
||||
@ -133,7 +136,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||
|
||||
if field.DBName != "" {
|
||||
// nonexistence or shortest path or first appear prioritized if has permission
|
||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
|
||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
|
||||
schema.DBNames = append(schema.DBNames, field.DBName)
|
||||
}
|
||||
@ -219,7 +222,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
|
||||
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && field.Creatable {
|
||||
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, schema.err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package schema_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@ -194,6 +195,7 @@ func TestEmbeddedStruct(t *testing.T) {
|
||||
ID int
|
||||
OwnerID int
|
||||
Name string
|
||||
Ignored string `gorm:"-"`
|
||||
}
|
||||
|
||||
type Corp struct {
|
||||
@ -211,15 +213,87 @@ func TestEmbeddedStruct(t *testing.T) {
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
||||
{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
|
||||
if f.Name != "Ignored" {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type CustomizedNamingStrategy struct {
|
||||
schema.NamingStrategy
|
||||
}
|
||||
|
||||
func (ns CustomizedNamingStrategy) ColumnName(table, column string) string {
|
||||
baseColumnName := ns.NamingStrategy.ColumnName(table, column)
|
||||
|
||||
if table == "" {
|
||||
return baseColumnName
|
||||
}
|
||||
|
||||
s := strings.Split(table, "_")
|
||||
|
||||
var prefix string
|
||||
switch len(s) {
|
||||
case 1:
|
||||
prefix = s[0][:3]
|
||||
case 2:
|
||||
prefix = s[0][:1] + s[1][:2]
|
||||
default:
|
||||
prefix = s[0][:1] + s[1][:1] + s[2][:1]
|
||||
}
|
||||
return prefix + "_" + baseColumnName
|
||||
}
|
||||
|
||||
func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
|
||||
type CorpBase struct {
|
||||
gorm.Model
|
||||
OwnerID string
|
||||
}
|
||||
|
||||
type Company struct {
|
||||
ID int
|
||||
OwnerID int
|
||||
Name string
|
||||
Ignored string `gorm:"-"`
|
||||
}
|
||||
|
||||
type Corp struct {
|
||||
CorpBase
|
||||
Base Company `gorm:"embedded;embeddedPrefix:company_"`
|
||||
}
|
||||
|
||||
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
|
||||
}
|
||||
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
||||
{Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
|
||||
{Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
|
||||
if f.Name != "Ignored" {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa
|
||||
return columns, queryValues
|
||||
}
|
||||
}
|
||||
|
||||
type embeddedNamer struct {
|
||||
Table string
|
||||
Namer
|
||||
}
|
||||
|
@ -25,14 +25,7 @@ func (n DeletedAt) Value() (driver.Value, error) {
|
||||
}
|
||||
|
||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{
|
||||
clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
|
||||
Value: nil,
|
||||
},
|
||||
}},
|
||||
}
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f}}
|
||||
}
|
||||
|
||||
type SoftDeleteQueryClause struct {
|
||||
|
17
statement.go
17
statement.go
@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
@ -189,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
writer.WriteString("(NULL)")
|
||||
}
|
||||
case *DB:
|
||||
subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance()
|
||||
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance()
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
writer.WriteString(subdb.Statement.SQL.String())
|
||||
@ -298,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := v[key].(driver.Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
} else if _, ok := v[key].(Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
} else {
|
||||
values := make([]interface{}, reflectValue.Len())
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
||||
}
|
||||
default:
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
}
|
||||
@ -317,9 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
||||
if field.Readable {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -330,9 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
||||
if field.Readable {
|
||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -287,6 +288,30 @@ func TestCreateEmptyStruct(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateEmptySlice(t *testing.T) {
|
||||
var data = []User{}
|
||||
if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice {
|
||||
t.Errorf("no data should be created, got %v", err)
|
||||
}
|
||||
|
||||
var sliceMap = []map[string]interface{}{}
|
||||
if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice {
|
||||
t.Errorf("no data should be created, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInvalidSlice(t *testing.T) {
|
||||
users := []*User{
|
||||
GetUser("invalid_slice_1", Config{}),
|
||||
GetUser("invalid_slice_2", Config{}),
|
||||
nil,
|
||||
}
|
||||
|
||||
if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) {
|
||||
t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithExistingTimestamp(t *testing.T) {
|
||||
user := User{Name: "CreateUserExistingTimestamp"}
|
||||
curTime := now.MustParse("2016-01-01")
|
||||
|
@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) {
|
||||
FieldAllowSave3 string `gorm:"->:false;<-:create"`
|
||||
FieldReadonly string `gorm:"->"`
|
||||
FieldIgnore string `gorm:"-"`
|
||||
AutoUnixCreateTime int64 `gorm:"autocreatetime"`
|
||||
AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"`
|
||||
AutoUnixCreateTime int32 `gorm:"autocreatetime"`
|
||||
AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"`
|
||||
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
|
||||
AutoUnixUpdateTime int64 `gorm:"autoupdatetime"`
|
||||
AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"`
|
||||
AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"`
|
||||
AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"`
|
||||
AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"`
|
||||
AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&CustomizeFieldStruct{})
|
||||
@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) {
|
||||
t.Fatalf("invalid result: %#v", result)
|
||||
}
|
||||
|
||||
if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 {
|
||||
if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 {
|
||||
t.Fatalf("invalid create/update unix time: %#v", result)
|
||||
}
|
||||
|
||||
if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 {
|
||||
if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 {
|
||||
t.Fatalf("invalid create/update unix milli time: %#v", result)
|
||||
}
|
||||
|
||||
if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 {
|
||||
if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 {
|
||||
t.Fatalf("invalid create/update unix nano time: %#v", result)
|
||||
}
|
||||
|
||||
@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) {
|
||||
var createWithDefaultTimeResult CustomizeFieldStruct
|
||||
DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name)
|
||||
|
||||
if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
|
||||
if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
|
||||
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
|
||||
}
|
||||
|
||||
if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
|
||||
if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 {
|
||||
t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult)
|
||||
}
|
||||
|
||||
if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
|
||||
if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
|
||||
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) {
|
||||
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteWithAssociations(t *testing.T) {
|
||||
user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1})
|
||||
|
||||
if err := DB.Create(user).Error; err != nil {
|
||||
t.Fatalf("failed to create user, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil {
|
||||
t.Fatalf("failed to delete user, got error %v", err)
|
||||
}
|
||||
|
||||
for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} {
|
||||
if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value {
|
||||
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} {
|
||||
if count := DB.Model(&user).Association(key).Count(); count != value {
|
||||
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSliceWithAssociations(t *testing.T) {
|
||||
users := []User{
|
||||
*GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}),
|
||||
*GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}),
|
||||
*GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}),
|
||||
*GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}),
|
||||
}
|
||||
|
||||
if err := DB.Create(users).Error; err != nil {
|
||||
t.Fatalf("failed to create user, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil {
|
||||
t.Fatalf("failed to delete user, got error %v", err)
|
||||
}
|
||||
|
||||
for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} {
|
||||
if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value {
|
||||
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} {
|
||||
if count := DB.Model(&users).Association(key).Count(); count != value {
|
||||
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -163,6 +163,8 @@ func TestEmbeddedRelations(t *testing.T) {
|
||||
DB.Migrator().DropTable(&AdvancedUser{})
|
||||
|
||||
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
|
||||
if DB.Dialector.Name() != "sqlite" {
|
||||
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
12
tests/go.mod
12
tests/go.mod
@ -7,11 +7,13 @@ require (
|
||||
github.com/jinzhu/now v1.1.1
|
||||
github.com/lib/pq v1.6.0
|
||||
github.com/stevefan1999-personal/gorm-driver-oracle latest
|
||||
gorm.io/driver/mysql v1.0.0
|
||||
gorm.io/driver/postgres v1.0.0
|
||||
gorm.io/driver/sqlite v1.1.0
|
||||
gorm.io/driver/sqlserver v1.0.1
|
||||
gorm.io/gorm v1.9.19
|
||||
gorm.io/driver/mysql v1.0.1
|
||||
gorm.io/driver/postgres v1.0.2
|
||||
gorm.io/driver/sqlite v1.1.3
|
||||
gorm.io/driver/sqlserver v1.0.4
|
||||
gorm.io/gorm v1.20.2
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
||||
replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2
|
||||
|
@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
|
||||
t.Run("Pets", func(t *testing.T) {
|
||||
if len(user.Pets) != len(expect.Pets) {
|
||||
t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets))
|
||||
t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets))
|
||||
}
|
||||
|
||||
sort.Slice(user.Pets, func(i, j int) bool {
|
||||
@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
|
||||
t.Run("Toys", func(t *testing.T) {
|
||||
if len(user.Toys) != len(expect.Toys) {
|
||||
t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys))
|
||||
t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys))
|
||||
}
|
||||
|
||||
sort.Slice(user.Toys, func(i, j int) bool {
|
||||
@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
|
||||
t.Run("Team", func(t *testing.T) {
|
||||
if len(user.Team) != len(expect.Team) {
|
||||
t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team))
|
||||
t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team))
|
||||
}
|
||||
|
||||
sort.Slice(user.Team, func(i, j int) bool {
|
||||
@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
|
||||
t.Run("Languages", func(t *testing.T) {
|
||||
if len(user.Languages) != len(expect.Languages) {
|
||||
t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages))
|
||||
t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages))
|
||||
}
|
||||
|
||||
sort.Slice(user.Languages, func(i, j int) bool {
|
||||
@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
|
||||
t.Run("Friends", func(t *testing.T) {
|
||||
if len(user.Friends) != len(expect.Friends) {
|
||||
t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends))
|
||||
t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends))
|
||||
}
|
||||
|
||||
sort.Slice(user.Friends, func(i, j int) bool {
|
||||
|
@ -5,12 +5,14 @@ import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Person struct {
|
||||
ID int
|
||||
Name string
|
||||
Addresses []Address `gorm:"many2many:person_addresses;"`
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) {
|
||||
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
|
||||
t.Fatalf("address should be deleted when clear with unscoped")
|
||||
}
|
||||
|
||||
address2_1 := Address{Name: "address 2-1"}
|
||||
address2_2 := Address{Name: "address 2-2"}
|
||||
person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}}
|
||||
DB.Create(&person2)
|
||||
if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil {
|
||||
t.Fatalf("failed to delete person, got error: %v", err)
|
||||
}
|
||||
|
||||
if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 {
|
||||
t.Errorf("person's addresses expects 2, got %v", count)
|
||||
}
|
||||
|
||||
if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 {
|
||||
t.Errorf("person's addresses expects 2, got %v", count)
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type Blog struct {
|
||||
@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
|
||||
t.Fatalf("EN Blog's tags should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositePrimaryKeysAssociations(t *testing.T) {
|
||||
type Label struct {
|
||||
BookID *uint `gorm:"primarykey"`
|
||||
Name string `gorm:"primarykey"`
|
||||
Value string
|
||||
}
|
||||
|
||||
type Book struct {
|
||||
ID int
|
||||
Name string
|
||||
Labels []Label
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Label{}, &Book{})
|
||||
if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil {
|
||||
t.Fatalf("failed to migrate")
|
||||
}
|
||||
|
||||
book := Book{
|
||||
Name: "my book",
|
||||
Labels: []Label{
|
||||
{Name: "region", Value: "emea"},
|
||||
},
|
||||
}
|
||||
|
||||
DB.Create(&book)
|
||||
|
||||
var result Book
|
||||
if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil {
|
||||
t.Fatalf("failed to preload, got error %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, book, result)
|
||||
}
|
||||
|
@ -48,10 +48,22 @@ func TestNamedArg(t *testing.T) {
|
||||
t.Errorf("failed to update with named arg")
|
||||
}
|
||||
|
||||
namedUser.Name1 = "jinzhu-new"
|
||||
namedUser.Name2 = "jinzhu-new2"
|
||||
namedUser.Name3 = "jinzhu-new"
|
||||
|
||||
var result5 NamedUser
|
||||
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil {
|
||||
t.Errorf("failed to update with named arg")
|
||||
}
|
||||
|
||||
AssertEqual(t, result4, namedUser)
|
||||
AssertEqual(t, result5, namedUser)
|
||||
|
||||
var result6 NamedUser
|
||||
if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name
|
||||
AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil {
|
||||
t.Errorf("failed to update with named arg")
|
||||
}
|
||||
|
||||
AssertEqual(t, result6, namedUser)
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
@ -31,6 +33,20 @@ func TestPreloadWithAssociations(t *testing.T) {
|
||||
var user2 User
|
||||
DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
var user3 = *GetUser("preload_with_associations_new", Config{
|
||||
Account: true,
|
||||
Pets: 2,
|
||||
Toys: 3,
|
||||
Company: true,
|
||||
Manager: true,
|
||||
Team: 4,
|
||||
Languages: 3,
|
||||
Friends: 1,
|
||||
})
|
||||
|
||||
DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user3, user)
|
||||
}
|
||||
|
||||
func TestNestedPreload(t *testing.T) {
|
||||
@ -174,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) {
|
||||
CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreloadEmptyData(t *testing.T) {
|
||||
var user = *GetUser("user_without_associations", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name)
|
||||
|
||||
if r, err := json.Marshal(&user); err != nil {
|
||||
t.Errorf("failed to marshal users, got error %v", err)
|
||||
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
|
||||
t.Errorf("json marshal is not empty slice, got %v", string(r))
|
||||
}
|
||||
|
||||
var results []User
|
||||
DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name)
|
||||
|
||||
if r, err := json.Marshal(&results); err != nil {
|
||||
t.Errorf("failed to marshal users, got error %v", err)
|
||||
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
|
||||
t.Errorf("json marshal is not empty slice, got %v", string(r))
|
||||
}
|
||||
}
|
||||
|
52
tests/prepared_stmt_test.go
Normal file
52
tests/prepared_stmt_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestPreparedStmt(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
txCtx := tx.WithContext(ctx)
|
||||
|
||||
user := *GetUser("prepared_stmt", Config{})
|
||||
|
||||
txCtx.Create(&user)
|
||||
|
||||
var result1 User
|
||||
if err := txCtx.Find(&result1, user.ID).Error; err != nil {
|
||||
t.Fatalf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
var result2 User
|
||||
if err := tx.Find(&result2, user.ID).Error; err != nil {
|
||||
t.Fatalf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
user2 := *GetUser("prepared_stmt2", Config{})
|
||||
if err := txCtx.Create(&user2).Error; err == nil {
|
||||
t.Fatalf("should failed to create with timeout context")
|
||||
}
|
||||
|
||||
if err := tx.Create(&user2).Error; err != nil {
|
||||
t.Fatalf("failed to create, got error %v", err)
|
||||
}
|
||||
|
||||
var result3 User
|
||||
if err := tx.Find(&result3, user2.ID).Error; err != nil {
|
||||
t.Fatalf("no error should happen but got %v", err)
|
||||
}
|
||||
}
|
@ -1,11 +1,13 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -61,6 +63,54 @@ func TestFind(t *testing.T) {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if _, ok := first[dbName].(string); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||
}
|
||||
case "Age":
|
||||
if _, ok := first[dbName].(uint); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if _, ok := first[dbName].(*time.Time); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FirstMapWithTable", func(t *testing.T) {
|
||||
var first = map[string]interface{}{}
|
||||
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
|
||||
t.Errorf("errors happened when query first: %v", err)
|
||||
} else {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
resultType := reflect.ValueOf(first[dbName]).Type().Name()
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if !strings.Contains(resultType, "string") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||
}
|
||||
case "Age":
|
||||
if !strings.Contains(resultType, "int") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
@ -86,13 +136,29 @@ func TestFind(t *testing.T) {
|
||||
t.Run("FirstSliceOfMap", func(t *testing.T) {
|
||||
var allMap = []map[string]interface{}{}
|
||||
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||
t.Errorf("errors happened when query first: %v", err)
|
||||
t.Errorf("errors happened when query find: %v", err)
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if _, ok := allMap[idx][dbName].(string); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||
}
|
||||
case "Age":
|
||||
if _, ok := allMap[idx][dbName].(uint); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
@ -101,6 +167,58 @@ func TestFind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
|
||||
var allMap = []map[string]interface{}{}
|
||||
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||
t.Errorf("errors happened when query find: %v", err)
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if !strings.Contains(resultType, "string") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||
}
|
||||
case "Age":
|
||||
if !strings.Contains(resultType, "int") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
var models []User
|
||||
if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 {
|
||||
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models))
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
CheckUser(t, models[idx], user)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var none []User
|
||||
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
|
||||
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryWithAssociation(t *testing.T) {
|
||||
@ -228,6 +346,11 @@ func TestNot(t *testing.T) {
|
||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
|
||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryDB.Not("name = ?", "jinzhu").Find(&User{})
|
||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||
@ -309,6 +432,33 @@ func TestPluck(t *testing.T) {
|
||||
t.Errorf("Unexpected result on pluck id, got %+v", ids)
|
||||
}
|
||||
}
|
||||
|
||||
var times []time.Time
|
||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil {
|
||||
t.Errorf("got error when pluck time: %v", err)
|
||||
}
|
||||
|
||||
for idx, tv := range times {
|
||||
AssertEqual(t, tv, users[idx].CreatedAt)
|
||||
}
|
||||
|
||||
var ptrtimes []*time.Time
|
||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil {
|
||||
t.Errorf("got error when pluck time: %v", err)
|
||||
}
|
||||
|
||||
for idx, tv := range ptrtimes {
|
||||
AssertEqual(t, tv, users[idx].CreatedAt)
|
||||
}
|
||||
|
||||
var nulltimes []sql.NullTime
|
||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil {
|
||||
t.Errorf("got error when pluck time: %v", err)
|
||||
}
|
||||
|
||||
for idx, tv := range nulltimes {
|
||||
AssertEqual(t, tv.Time, users[idx].CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelect(t *testing.T) {
|
||||
@ -508,6 +658,7 @@ func TestLimit(t *testing.T) {
|
||||
{Name: "LimitUser3", Age: 20},
|
||||
{Name: "LimitUser4", Age: 10},
|
||||
{Name: "LimitUser5", Age: 20},
|
||||
{Name: "LimitUser6", Age: 20},
|
||||
}
|
||||
|
||||
DB.Create(&users)
|
||||
@ -516,7 +667,7 @@ func TestLimit(t *testing.T) {
|
||||
DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
|
||||
|
||||
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
|
||||
t.Errorf("Limit should works")
|
||||
t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3))
|
||||
}
|
||||
}
|
||||
|
||||
@ -531,6 +682,7 @@ func TestOffset(t *testing.T) {
|
||||
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||
t.Errorf("Offset should work")
|
||||
}
|
||||
|
||||
DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
||||
|
||||
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||
@ -698,3 +850,11 @@ func TestScanNullValue(t *testing.T) {
|
||||
t.Fatalf("failed to query slice data with null age, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryWithTableAndConditions(t *testing.T) {
|
||||
result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"})
|
||||
|
||||
if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -16,14 +17,25 @@ func TestScan(t *testing.T) {
|
||||
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||
|
||||
type result struct {
|
||||
ID uint
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
var res result
|
||||
DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res)
|
||||
if res.Name != user3.Name || res.Age != int(user3.Age) {
|
||||
t.Errorf("Scan into struct should work")
|
||||
DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res)
|
||||
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
|
||||
}
|
||||
|
||||
DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res)
|
||||
if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2)
|
||||
}
|
||||
|
||||
DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res)
|
||||
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
|
||||
}
|
||||
|
||||
var doubleAgeRes = &result{}
|
||||
@ -39,11 +51,11 @@ func TestScan(t *testing.T) {
|
||||
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results)
|
||||
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return strings.Compare(results[i].Name, results[j].Name) < -1
|
||||
return strings.Compare(results[i].Name, results[j].Name) <= -1
|
||||
})
|
||||
|
||||
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
|
||||
t.Errorf("Scan into struct map")
|
||||
t.Errorf("Scan into struct map, got %#v", results)
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,7 +84,21 @@ func TestScanRows(t *testing.T) {
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return strings.Compare(results[i].Name, results[j].Name) <= -1
|
||||
})
|
||||
|
||||
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||
t.Errorf("Should find expected results")
|
||||
}
|
||||
|
||||
var ages int
|
||||
if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 {
|
||||
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages)
|
||||
}
|
||||
|
||||
var name string
|
||||
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
|
||||
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) {
|
||||
t.Errorf("expects: %v, got %v", expects, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineStringConditions(t *testing.T) {
|
||||
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ if [ -d tests ]
|
||||
then
|
||||
cd tests
|
||||
cp go.mod go.mod.bak
|
||||
sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod
|
||||
sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod
|
||||
cd ..
|
||||
fi
|
||||
|
||||
|
@ -22,7 +22,7 @@ var DB *gorm.DB
|
||||
func init() {
|
||||
var err error
|
||||
if DB, err = OpenTestConnection(); err != nil {
|
||||
log.Printf("failed to connect database, got error %v\n", err)
|
||||
log.Printf("failed to connect database, got error %v", err)
|
||||
os.Exit(1)
|
||||
} else {
|
||||
sqlDB, err := DB.DB()
|
||||
@ -31,7 +31,7 @@ func init() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("failed to connect database, got error %v\n", err)
|
||||
log.Printf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
RunMigrations()
|
||||
|
@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionOnClosedConn(t *testing.T) {
|
||||
DB, err := OpenTestConnection()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
rawDB, _ := DB.DB()
|
||||
rawDB.Close()
|
||||
|
||||
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
return nil
|
||||
}); err == nil {
|
||||
t.Errorf("should returns error when commit with closed conn, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error {
|
||||
return nil
|
||||
}); err == nil {
|
||||
t.Errorf("should returns error when commit with closed conn, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) {
|
||||
var user2 User
|
||||
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
user.Company.Name += "new"
|
||||
user.Manager.Name += "new"
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
||||
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
for _, pet := range user.Pets {
|
||||
pet.Name += "new"
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Pets").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Pets").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
|
||||
t.Run("Polymorphic", func(t *testing.T) {
|
||||
var user = *GetUser("update-has-many", Config{})
|
||||
|
||||
@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
||||
var user2 User
|
||||
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
for idx := range user.Toys {
|
||||
user.Toys[idx].Name += "new"
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Toys").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Toys").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
})
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) {
|
||||
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
user.Account.Number += "new"
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Account").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Account").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
|
||||
t.Run("Polymorphic", func(t *testing.T) {
|
||||
var pet = Pet{Name: "create"}
|
||||
|
||||
@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) {
|
||||
var pet2 Pet
|
||||
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
|
||||
CheckPet(t, pet2, pet)
|
||||
|
||||
pet.Toy.Name += "new"
|
||||
if err := DB.Save(&pet).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var pet3 Pet
|
||||
DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
|
||||
CheckPet(t, pet2, pet3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var pet4 Pet
|
||||
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
|
||||
CheckPet(t, pet4, pet)
|
||||
})
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) {
|
||||
var user2 User
|
||||
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
for idx := range user.Friends {
|
||||
user.Friends[idx].Name += "new"
|
||||
}
|
||||
|
||||
for idx := range user.Languages {
|
||||
user.Languages[idx].Name += "new"
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
}
|
||||
|
@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) {
|
||||
var result2 Language
|
||||
DB.First(&result2, "code = ?", "save")
|
||||
AssertEqual(t, result2, lang)
|
||||
|
||||
DB.Table("langs").Migrator().DropTable(&Language{})
|
||||
DB.Table("langs").AutoMigrate(&Language{})
|
||||
|
||||
if err := DB.Table("langs").Save(&lang).Error; err != nil {
|
||||
t.Errorf("no error should happen when creating data, but got %v", err)
|
||||
}
|
||||
|
||||
var result3 Language
|
||||
if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name {
|
||||
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3)
|
||||
}
|
||||
|
||||
lang.Name += "name2"
|
||||
if err := DB.Table("langs").Save(&lang).Error; err != nil {
|
||||
t.Errorf("no error should happen when creating data, but got %v", err)
|
||||
}
|
||||
|
||||
var result4 Language
|
||||
if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name {
|
||||
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
|
||||
}
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ func FileWithLineNum() string {
|
||||
}
|
||||
|
||||
func IsValidDBNameChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
|
||||
}
|
||||
|
||||
func CheckTruth(val interface{}) bool {
|
||||
@ -83,3 +83,31 @@ func AssertEqual(src, dst interface{}) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ToString(value interface{}) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case int:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(v, 10)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user