Merge branch 'master' into create_on_confilct
# Conflicts: # tests/create_test.go
This commit is contained in:
		
						commit
						441092d82f
					
				
							
								
								
									
										2
									
								
								.github/workflows/labeler.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/labeler.yml
									
									
									
									
										vendored
									
									
								
							| @ -11,7 +11,7 @@ jobs: | ||||
|     name: Label issues and pull requests | ||||
|     steps: | ||||
|       - name: check out | ||||
|         uses: actions/checkout@v3 | ||||
|         uses: actions/checkout@v4 | ||||
| 
 | ||||
|       - name: labeler | ||||
|         uses: jinzhu/super-labeler-action@develop | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/reviewdog.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/reviewdog.yml
									
									
									
									
										vendored
									
									
								
							| @ -6,7 +6,7 @@ jobs: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - name: Check out code into the Go module directory | ||||
|         uses: actions/checkout@v3 | ||||
|         uses: actions/checkout@v4 | ||||
|       - name: golangci-lint | ||||
|         uses: reviewdog/action-golangci-lint@v2 | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										24
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										24
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							| @ -16,7 +16,7 @@ jobs: | ||||
|   sqlite: | ||||
|     strategy: | ||||
|       matrix: | ||||
|         go: ['1.19', '1.18'] | ||||
|         go: ['1.21', '1.20', '1.19'] | ||||
|         platform: [ubuntu-latest] # can not run in windows OS | ||||
|     runs-on: ${{ matrix.platform }} | ||||
| 
 | ||||
| @ -27,7 +27,7 @@ jobs: | ||||
|         go-version: ${{ matrix.go }} | ||||
| 
 | ||||
|     - name: Check out code into the Go module directory | ||||
|       uses: actions/checkout@v3 | ||||
|       uses: actions/checkout@v4 | ||||
| 
 | ||||
|     - name: go mod package cache | ||||
|       uses: actions/cache@v3 | ||||
| @ -42,7 +42,7 @@ jobs: | ||||
|     strategy: | ||||
|       matrix: | ||||
|         dbversion: ['mysql:latest', 'mysql:5.7'] | ||||
|         go: ['1.19', '1.18'] | ||||
|         go: ['1.21', '1.20', '1.19'] | ||||
|         platform: [ubuntu-latest] | ||||
|     runs-on: ${{ matrix.platform }} | ||||
| 
 | ||||
| @ -70,7 +70,7 @@ jobs: | ||||
|         go-version: ${{ matrix.go }} | ||||
| 
 | ||||
|     - name: Check out code into the Go module directory | ||||
|       uses: actions/checkout@v3 | ||||
|       uses: actions/checkout@v4 | ||||
| 
 | ||||
|     - name: go mod package cache | ||||
|       uses: actions/cache@v3 | ||||
| @ -85,7 +85,7 @@ jobs: | ||||
|     strategy: | ||||
|       matrix: | ||||
|         dbversion: [ 'mariadb:latest' ] | ||||
|         go: [ '1.19', '1.18' ] | ||||
|         go: ['1.21', '1.20', '1.19'] | ||||
|         platform: [ ubuntu-latest ] | ||||
|     runs-on: ${{ matrix.platform }} | ||||
| 
 | ||||
| @ -113,7 +113,7 @@ jobs: | ||||
|           go-version: ${{ matrix.go }} | ||||
| 
 | ||||
|       - name: Check out code into the Go module directory | ||||
|         uses: actions/checkout@v3 | ||||
|         uses: actions/checkout@v4 | ||||
| 
 | ||||
|       - name: go mod package cache | ||||
|         uses: actions/cache@v3 | ||||
| @ -128,7 +128,7 @@ jobs: | ||||
|     strategy: | ||||
|       matrix: | ||||
|         dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] | ||||
|         go: ['1.19', '1.18'] | ||||
|         go: ['1.21', '1.20', '1.19'] | ||||
|         platform: [ubuntu-latest] # can not run in macOS and Windows | ||||
|     runs-on: ${{ matrix.platform }} | ||||
| 
 | ||||
| @ -156,7 +156,7 @@ jobs: | ||||
|         go-version: ${{ matrix.go }} | ||||
| 
 | ||||
|     - name: Check out code into the Go module directory | ||||
|       uses: actions/checkout@v3 | ||||
|       uses: actions/checkout@v4 | ||||
| 
 | ||||
|     - name: go mod package cache | ||||
|       uses: actions/cache@v3 | ||||
| @ -170,7 +170,7 @@ jobs: | ||||
|   sqlserver: | ||||
|     strategy: | ||||
|       matrix: | ||||
|         go: ['1.19', '1.18'] | ||||
|         go: ['1.21', '1.20', '1.19'] | ||||
|         platform: [ubuntu-latest] # can not run test in macOS and windows | ||||
|     runs-on: ${{ matrix.platform }} | ||||
| 
 | ||||
| @ -199,7 +199,7 @@ jobs: | ||||
|         go-version: ${{ matrix.go }} | ||||
| 
 | ||||
|     - name: Check out code into the Go module directory | ||||
|       uses: actions/checkout@v3 | ||||
|       uses: actions/checkout@v4 | ||||
| 
 | ||||
|     - name: go mod package cache | ||||
|       uses: actions/cache@v3 | ||||
| @ -214,7 +214,7 @@ jobs: | ||||
|     strategy: | ||||
|       matrix: | ||||
|         dbversion: [ 'v6.5.0' ] | ||||
|         go: [ '1.19', '1.18' ] | ||||
|         go: ['1.21', '1.20', '1.19'] | ||||
|         platform: [ ubuntu-latest ] | ||||
|     runs-on: ${{ matrix.platform }} | ||||
| 
 | ||||
| @ -231,7 +231,7 @@ jobs: | ||||
|           go-version: ${{ matrix.go }} | ||||
| 
 | ||||
|       - name: Check out code into the Go module directory | ||||
|         uses: actions/checkout@v3 | ||||
|         uses: actions/checkout@v4 | ||||
| 
 | ||||
| 
 | ||||
|       - name: go mod package cache | ||||
|  | ||||
| @ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly. | ||||
| 
 | ||||
| © Jinzhu, 2013~time.Now | ||||
| 
 | ||||
| Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) | ||||
| Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE) | ||||
|  | ||||
| @ -102,13 +102,53 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 		if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||
| 			db.Statement.Schema.PrioritizedPrimaryField != nil && | ||||
| 			db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 			insertID, err := result.LastInsertId() | ||||
| 			insertOk := err == nil && insertID > 0 | ||||
| 			if !insertOk { | ||||
| 				db.AddError(err) | ||||
| 		if db.RowsAffected == 0 { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		var ( | ||||
| 			pkField     *schema.Field | ||||
| 			pkFieldName = "@id" | ||||
| 		) | ||||
| 		if db.Statement.Schema != nil { | ||||
| 			if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 				return | ||||
| 			} | ||||
| 			pkField = db.Statement.Schema.PrioritizedPrimaryField | ||||
| 			pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		insertID, err := result.LastInsertId() | ||||
| 		insertOk := err == nil && insertID > 0 | ||||
| 		if !insertOk { | ||||
| 			db.AddError(err) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// append @id column with value for auto-increment primary key
 | ||||
| 		// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
 | ||||
| 		switch values := db.Statement.Dest.(type) { | ||||
| 		case map[string]interface{}: | ||||
| 			values[pkFieldName] = insertID | ||||
| 		case *map[string]interface{}: | ||||
| 			(*values)[pkFieldName] = insertID | ||||
| 		case []map[string]interface{}, *[]map[string]interface{}: | ||||
| 			mapValues, ok := values.([]map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				if v, ok := values.(*[]map[string]interface{}); ok { | ||||
| 					if *v != nil { | ||||
| 						mapValues = *v | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			for _, mapValue := range mapValues { | ||||
| 				if mapValue != nil { | ||||
| 					mapValue[pkFieldName] = insertID | ||||
| 				} | ||||
| 				insertID += schema.DefaultAutoIncrementIncrement | ||||
| 			} | ||||
| 		default: | ||||
| 			if pkField == nil { | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| @ -121,10 +161,10 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 							break | ||||
| 						} | ||||
| 
 | ||||
| 						_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) | ||||
| 						_, isZero := pkField.ValueOf(db.Statement.Context, rv) | ||||
| 						if isZero { | ||||
| 							db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 							db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID -= pkField.AutoIncrementIncrement | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| @ -134,16 +174,16 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 							break | ||||
| 						} | ||||
| 
 | ||||
| 						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { | ||||
| 							db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||
| 						if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { | ||||
| 							db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) | ||||
| 							insertID += pkField.AutoIncrementIncrement | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			case reflect.Struct: | ||||
| 				_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) | ||||
| 				_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) | ||||
| 				if isZero { | ||||
| 					db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) | ||||
| 					db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) { | ||||
| 	for _, v := range []byte(expr.SQL) { | ||||
| 		if v == '@' && !inName { | ||||
| 			inName = true | ||||
| 			name = []byte{} | ||||
| 			name = name[:0] | ||||
| 		} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { | ||||
| 			if inName { | ||||
| 				if nv, ok := namedMap[string(name)]; ok { | ||||
| @ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) { | ||||
| 
 | ||||
| 	switch eq.Value.(type) { | ||||
| 	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: | ||||
| 		builder.WriteString(" IN (") | ||||
| 		rv := reflect.ValueOf(eq.Value) | ||||
| 		for i := 0; i < rv.Len(); i++ { | ||||
| 			if i > 0 { | ||||
| 				builder.WriteByte(',') | ||||
| 		if rv.Len() == 0 { | ||||
| 			builder.WriteString(" IN (NULL)") | ||||
| 		} else { | ||||
| 			builder.WriteString(" IN (") | ||||
| 			for i := 0; i < rv.Len(); i++ { | ||||
| 				if i > 0 { | ||||
| 					builder.WriteByte(',') | ||||
| 				} | ||||
| 				builder.AddVar(builder, rv.Index(i).Interface()) | ||||
| 			} | ||||
| 			builder.AddVar(builder, rv.Index(i).Interface()) | ||||
| 			builder.WriteByte(')') | ||||
| 		} | ||||
| 		builder.WriteByte(')') | ||||
| 	default: | ||||
| 		if eqNil(eq.Value) { | ||||
| 			builder.WriteString(" IS NULL") | ||||
|  | ||||
| @ -199,6 +199,11 @@ func TestExpression(t *testing.T) { | ||||
| 		}, | ||||
| 		ExpectedVars: []interface{}{"a", "b"}, | ||||
| 		Result:       "`column-name` NOT IN (?,?)", | ||||
| 	}, { | ||||
| 		Expressions: []clause.Expression{ | ||||
| 			clause.Eq{Column: column, Value: []string{}}, | ||||
| 		}, | ||||
| 		Result: "`column-name` IN (NULL)", | ||||
| 	}, { | ||||
| 		Expressions: []clause.Expression{ | ||||
| 			clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, | ||||
|  | ||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							| @ -1,6 +1,6 @@ | ||||
| module gorm.io/gorm | ||||
| 
 | ||||
| go 1.16 | ||||
| go 1.18 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/jinzhu/inflection v1.0.0 | ||||
|  | ||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @ -1,6 +1,4 @@ | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= | ||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | ||||
| github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
|  | ||||
							
								
								
									
										22
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								gorm.go
									
									
									
									
									
								
							| @ -4,6 +4,7 @@ import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| @ -181,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { | ||||
| 		err = config.Dialector.Initialize(db) | ||||
| 
 | ||||
| 		if err != nil { | ||||
| 			if db, err := db.DB(); err == nil { | ||||
| 			if db, _ := db.DB(); db != nil { | ||||
| 				_ = db.Close() | ||||
| 			} | ||||
| 		} | ||||
| @ -374,9 +375,11 @@ func (db *DB) AddError(err error) error { | ||||
| // DB returns `*sql.DB`
 | ||||
| func (db *DB) DB() (*sql.DB, error) { | ||||
| 	connPool := db.ConnPool | ||||
| 
 | ||||
| 	if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { | ||||
| 		return connector.GetDBConnWithContext(db) | ||||
| 	if db.Statement != nil && db.Statement.ConnPool != nil { | ||||
| 		connPool = db.Statement.ConnPool | ||||
| 	} | ||||
| 	if tx, ok := connPool.(*sql.Tx); ok && tx != nil { | ||||
| 		return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil | ||||
| 	} | ||||
| 
 | ||||
| 	if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { | ||||
| @ -399,11 +402,12 @@ func (db *DB) getInstance() *DB { | ||||
| 		if db.clone == 1 { | ||||
| 			// clone with new statement
 | ||||
| 			tx.Statement = &Statement{ | ||||
| 				DB:       tx, | ||||
| 				ConnPool: db.Statement.ConnPool, | ||||
| 				Context:  db.Statement.Context, | ||||
| 				Clauses:  map[string]clause.Clause{}, | ||||
| 				Vars:     make([]interface{}, 0, 8), | ||||
| 				DB:        tx, | ||||
| 				ConnPool:  db.Statement.ConnPool, | ||||
| 				Context:   db.Statement.Context, | ||||
| 				Clauses:   map[string]clause.Clause{}, | ||||
| 				Vars:      make([]interface{}, 0, 8), | ||||
| 				SkipHooks: db.Statement.SkipHooks, | ||||
| 			} | ||||
| 		} else { | ||||
| 			// with clone statement
 | ||||
|  | ||||
| @ -77,12 +77,6 @@ type GetDBConnector interface { | ||||
| 	GetDBConn() (*sql.DB, error) | ||||
| } | ||||
| 
 | ||||
| // GetDBConnectorWithContext represents SQL db connector which takes into
 | ||||
| // account the current database context
 | ||||
| type GetDBConnectorWithContext interface { | ||||
| 	GetDBConnWithContext(db *DB) (*sql.DB, error) | ||||
| } | ||||
| 
 | ||||
| // Rows rows interface
 | ||||
| type Rows interface { | ||||
| 	Columns() ([]string, error) | ||||
|  | ||||
| @ -69,7 +69,7 @@ type Interface interface { | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	// Discard Discard logger will print any log to io.Discard
 | ||||
| 	// Discard logger will print any log to io.Discard
 | ||||
| 	Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) | ||||
| 	// Default Default logger
 | ||||
| 	Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ | ||||
| @ -78,7 +78,7 @@ var ( | ||||
| 		IgnoreRecordNotFoundError: false, | ||||
| 		Colorful:                  true, | ||||
| 	}) | ||||
| 	// Recorder Recorder logger records running SQL into a recorder instance
 | ||||
| 	// Recorder logger records running SQL into a recorder instance
 | ||||
| 	Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} | ||||
| ) | ||||
| 
 | ||||
| @ -129,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface { | ||||
| } | ||||
| 
 | ||||
| // Info print info
 | ||||
| func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { | ||||
| func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { | ||||
| 	if l.LogLevel >= Info { | ||||
| 		l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Warn print warn messages
 | ||||
| func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { | ||||
| func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { | ||||
| 	if l.LogLevel >= Warn { | ||||
| 		l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Error print error messages
 | ||||
| func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { | ||||
| func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { | ||||
| 	if l.LogLevel >= Error { | ||||
| 		l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Trace print sql message
 | ||||
| func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { | ||||
| //
 | ||||
| //nolint:cyclop
 | ||||
| func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { | ||||
| 	if l.LogLevel <= Silent { | ||||
| 		return | ||||
| 	} | ||||
| @ -182,8 +184,8 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Trace print sql message
 | ||||
| func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { | ||||
| // ParamsFilter filter params
 | ||||
| func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { | ||||
| 	if l.Config.ParameterizedQueries { | ||||
| 		return sql, nil | ||||
| 	} | ||||
| @ -198,8 +200,8 @@ type traceRecorder struct { | ||||
| 	Err          error | ||||
| } | ||||
| 
 | ||||
| // New new trace recorder
 | ||||
| func (l traceRecorder) New() *traceRecorder { | ||||
| // New trace recorder
 | ||||
| func (l *traceRecorder) New() *traceRecorder { | ||||
| 	return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -28,8 +28,10 @@ func isPrintable(s string) bool { | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // A list of Go types that should be converted to SQL primitives
 | ||||
| var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} | ||||
| 
 | ||||
| // RegEx matches only numeric values
 | ||||
| var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) | ||||
| 
 | ||||
| // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 | ||||
| @ -93,8 +95,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a | ||||
| 			} | ||||
| 		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: | ||||
| 			vars[idx] = utils.ToString(v) | ||||
| 		case float64, float32: | ||||
| 			vars[idx] = fmt.Sprintf("%.6f", v) | ||||
| 		case float32: | ||||
| 			vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) | ||||
| 		case float64: | ||||
| 			vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) | ||||
| 		case string: | ||||
| 			vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper | ||||
| 		default: | ||||
|  | ||||
| @ -57,43 +57,55 @@ func TestExplainSQL(t *testing.T) { | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", | ||||
| 			NumericRegexp: regexp.MustCompile(`@p(\d+)`), | ||||
| 			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", | ||||
| 			NumericRegexp: regexp.MustCompile(`\$(\d+)`), | ||||
| 			Vars:          []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", | ||||
| 			NumericRegexp: regexp.MustCompile(`@p(\d+)`), | ||||
| 			Vars:          []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, | ||||
| 			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), | ||||
| 			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, | ||||
| 			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), | ||||
| 			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, | ||||
| 			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, | ||||
| 			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -87,6 +87,8 @@ type Migrator interface { | ||||
| 	DropColumn(dst interface{}, field string) error | ||||
| 	AlterColumn(dst interface{}, field string) error | ||||
| 	MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error | ||||
| 	// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
 | ||||
| 	MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error | ||||
| 	HasColumn(dst interface{}, field string) bool | ||||
| 	RenameColumn(dst interface{}, oldName, field string) error | ||||
| 	ColumnTypes(dst interface{}) ([]ColumnType, error) | ||||
|  | ||||
| @ -16,8 +16,19 @@ import ( | ||||
| 	"gorm.io/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| // This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
 | ||||
| // with a possible trailing non-digit character (\D?).
 | ||||
| 
 | ||||
| // For example, values that can pass this regular expression are:
 | ||||
| // - "123"
 | ||||
| // - "abc456"
 | ||||
| // -"%$#@789"
 | ||||
| var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) | ||||
| 
 | ||||
| // TODO:? Create const vars for raw sql queries ?
 | ||||
| 
 | ||||
| var _ gorm.Migrator = (*Migrator)(nil) | ||||
| 
 | ||||
| // Migrator m struct
 | ||||
| type Migrator struct { | ||||
| 	Config | ||||
| @ -208,7 +219,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { | ||||
| 				field := stmt.Schema.FieldsByDBName[dbName] | ||||
| 				if !field.IgnoreMigration { | ||||
| 					createTableSQL += "? ?" | ||||
| 					hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") | ||||
| 					hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") | ||||
| 					values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) | ||||
| 					createTableSQL += "," | ||||
| 				} | ||||
| @ -530,6 +541,26 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { | ||||
| 	unique, ok := columnType.Unique() | ||||
| 	if !ok || field.PrimaryKey { | ||||
| 		return nil // skip primary key
 | ||||
| 	} | ||||
| 	// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
 | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		// We're currently only receiving boolean values on `Unique` tag,
 | ||||
| 		// so the UniqueConstraint name is fixed
 | ||||
| 		constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) | ||||
| 		if unique && !field.Unique { | ||||
| 			return m.DB.Migrator().DropConstraint(value, constraint) | ||||
| 		} | ||||
| 		if !unique && field.Unique { | ||||
| 			return m.DB.Migrator().CreateConstraint(value, constraint) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
 | ||||
| func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { | ||||
| 	columnTypes := make([]gorm.ColumnType, 0) | ||||
|  | ||||
| @ -31,14 +31,14 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { | ||||
| 	if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { | ||||
| 		return dbConnector.GetDBConn() | ||||
| 	} | ||||
| 
 | ||||
| 	if sqldb, ok := db.ConnPool.(*sql.DB); ok { | ||||
| 		return sqldb, nil | ||||
| 	} | ||||
| 
 | ||||
| 	if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { | ||||
| 		return dbConnector.GetDBConn() | ||||
| 	} | ||||
| 
 | ||||
| 	return nil, ErrInvalidDB | ||||
| } | ||||
| 
 | ||||
| @ -54,15 +54,15 @@ func (db *PreparedStmtDB) Close() { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) Reset() { | ||||
| 	db.Mux.Lock() | ||||
| 	defer db.Mux.Unlock() | ||||
| func (sdb *PreparedStmtDB) Reset() { | ||||
| 	sdb.Mux.Lock() | ||||
| 	defer sdb.Mux.Unlock() | ||||
| 
 | ||||
| 	for _, stmt := range db.Stmts { | ||||
| 	for _, stmt := range sdb.Stmts { | ||||
| 		go stmt.Close() | ||||
| 	} | ||||
| 	db.PreparedSQL = make([]string, 0, 100) | ||||
| 	db.Stmts = make(map[string]*Stmt) | ||||
| 	sdb.PreparedSQL = make([]string, 0, 100) | ||||
| 	sdb.Stmts = make(map[string]*Stmt) | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { | ||||
| @ -127,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn | ||||
| 		tx, err := beginner.BeginTx(ctx, opt) | ||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||
| 	} | ||||
| 
 | ||||
| 	beginner, ok := db.ConnPool.(ConnPoolBeginner) | ||||
| 	if !ok { | ||||
| 		return nil, ErrInvalidTransaction | ||||
| 	} | ||||
| 
 | ||||
| 	connPool, err := beginner.BeginTx(ctx, opt) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if tx, ok := connPool.(Tx); ok { | ||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil | ||||
| 	} | ||||
| 	return nil, ErrInvalidTransaction | ||||
| } | ||||
| 
 | ||||
| @ -172,6 +185,10 @@ type PreparedStmtTX struct { | ||||
| 	PreparedStmtDB *PreparedStmtDB | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { | ||||
| 	return db.PreparedStmtDB.GetDBConn() | ||||
| } | ||||
| 
 | ||||
| func (tx *PreparedStmtTX) Commit() error { | ||||
| 	if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { | ||||
| 		return tx.Tx.Commit() | ||||
|  | ||||
| @ -49,6 +49,8 @@ const ( | ||||
| 	Bytes  DataType = "bytes" | ||||
| ) | ||||
| 
 | ||||
| const DefaultAutoIncrementIncrement int64 = 1 | ||||
| 
 | ||||
| // Field is the representation of model schema's field
 | ||||
| type Field struct { | ||||
| 	Name                   string | ||||
| @ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||
| 		NotNull:                utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), | ||||
| 		Unique:                 utils.CheckTruth(tagSetting["UNIQUE"]), | ||||
| 		Comment:                tagSetting["COMMENT"], | ||||
| 		AutoIncrementIncrement: 1, | ||||
| 		AutoIncrementIncrement: DefaultAutoIncrementIncrement, | ||||
| 	} | ||||
| 
 | ||||
| 	for field.IndirectFieldType.Kind() == reflect.Ptr { | ||||
|  | ||||
| @ -19,6 +19,7 @@ type Namer interface { | ||||
| 	RelationshipFKName(Relationship) string | ||||
| 	CheckerName(table, column string) string | ||||
| 	IndexName(table, column string) string | ||||
| 	UniqueName(table, column string) string | ||||
| } | ||||
| 
 | ||||
| // Replacer replacer interface like strings.Replacer
 | ||||
| @ -26,6 +27,8 @@ type Replacer interface { | ||||
| 	Replace(name string) string | ||||
| } | ||||
| 
 | ||||
| var _ Namer = (*NamingStrategy)(nil) | ||||
| 
 | ||||
| // NamingStrategy tables, columns naming strategy
 | ||||
| type NamingStrategy struct { | ||||
| 	TablePrefix         string | ||||
| @ -85,6 +88,11 @@ func (ns NamingStrategy) IndexName(table, column string) string { | ||||
| 	return ns.formatName("idx", table, ns.toDBName(column)) | ||||
| } | ||||
| 
 | ||||
| // UniqueName generate unique constraint name
 | ||||
| func (ns NamingStrategy) UniqueName(table, column string) string { | ||||
| 	return ns.formatName("uni", table, ns.toDBName(column)) | ||||
| } | ||||
| 
 | ||||
| func (ns NamingStrategy) formatName(prefix, table, name string) string { | ||||
| 	formattedName := strings.ReplaceAll(strings.Join([]string{ | ||||
| 		prefix, table, name, | ||||
|  | ||||
| @ -13,6 +13,20 @@ import ( | ||||
| 	"gorm.io/gorm/logger" | ||||
| ) | ||||
| 
 | ||||
| type callbackType string | ||||
| 
 | ||||
| const ( | ||||
| 	callbackTypeBeforeCreate callbackType = "BeforeCreate" | ||||
| 	callbackTypeBeforeUpdate callbackType = "BeforeUpdate" | ||||
| 	callbackTypeAfterCreate  callbackType = "AfterCreate" | ||||
| 	callbackTypeAfterUpdate  callbackType = "AfterUpdate" | ||||
| 	callbackTypeBeforeSave   callbackType = "BeforeSave" | ||||
| 	callbackTypeAfterSave    callbackType = "AfterSave" | ||||
| 	callbackTypeBeforeDelete callbackType = "BeforeDelete" | ||||
| 	callbackTypeAfterDelete  callbackType = "AfterDelete" | ||||
| 	callbackTypeAfterFind    callbackType = "AfterFind" | ||||
| ) | ||||
| 
 | ||||
| // ErrUnsupportedDataType unsupported data type
 | ||||
| var ErrUnsupportedDataType = errors.New("unsupported data type") | ||||
| 
 | ||||
| @ -288,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} | ||||
| 	for _, name := range callbacks { | ||||
| 		if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { | ||||
| 	callbackTypes := []callbackType{ | ||||
| 		callbackTypeBeforeCreate, callbackTypeAfterCreate, | ||||
| 		callbackTypeBeforeUpdate, callbackTypeAfterUpdate, | ||||
| 		callbackTypeBeforeSave, callbackTypeAfterSave, | ||||
| 		callbackTypeBeforeDelete, callbackTypeAfterDelete, | ||||
| 		callbackTypeAfterFind, | ||||
| 	} | ||||
| 	for _, cbName := range callbackTypes { | ||||
| 		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { | ||||
| 			switch methodValue.Type().String() { | ||||
| 			case "func(*gorm.DB) error": // TODO hack
 | ||||
| 				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) | ||||
| 				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) | ||||
| 			default: | ||||
| 				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) | ||||
| 				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @ -349,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | ||||
| 	return schema, schema.err | ||||
| } | ||||
| 
 | ||||
| // This unrolling is needed to show to the compiler the exact set of methods
 | ||||
| // that can be used on the modelType.
 | ||||
| // Prior to go1.22 any use of MethodByName would cause the linker to
 | ||||
| // abandon dead code elimination for the entire binary.
 | ||||
| // As of go1.22 the compiler supports one special case of a string constant
 | ||||
| // being passed to MethodByName. For enterprise customers or those building
 | ||||
| // large binaries, this gives a significant reduction in binary size.
 | ||||
| // https://github.com/golang/go/issues/62257
 | ||||
| func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { | ||||
| 	switch cbType { | ||||
| 	case callbackTypeBeforeCreate: | ||||
| 		return modelType.MethodByName(string(callbackTypeBeforeCreate)) | ||||
| 	case callbackTypeAfterCreate: | ||||
| 		return modelType.MethodByName(string(callbackTypeAfterCreate)) | ||||
| 	case callbackTypeBeforeUpdate: | ||||
| 		return modelType.MethodByName(string(callbackTypeBeforeUpdate)) | ||||
| 	case callbackTypeAfterUpdate: | ||||
| 		return modelType.MethodByName(string(callbackTypeAfterUpdate)) | ||||
| 	case callbackTypeBeforeSave: | ||||
| 		return modelType.MethodByName(string(callbackTypeBeforeSave)) | ||||
| 	case callbackTypeAfterSave: | ||||
| 		return modelType.MethodByName(string(callbackTypeAfterSave)) | ||||
| 	case callbackTypeBeforeDelete: | ||||
| 		return modelType.MethodByName(string(callbackTypeBeforeDelete)) | ||||
| 	case callbackTypeAfterDelete: | ||||
| 		return modelType.MethodByName(string(callbackTypeAfterDelete)) | ||||
| 	case callbackTypeAfterFind: | ||||
| 		return modelType.MethodByName(string(callbackTypeAfterFind)) | ||||
| 	default: | ||||
| 		return reflect.ValueOf(nil) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { | ||||
| 	modelType := reflect.ValueOf(dest).Type() | ||||
| 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { | ||||
|  | ||||
| @ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, | ||||
| 		notZero, zero bool | ||||
| 	) | ||||
| 
 | ||||
| 	if reflectValue.Kind() == reflect.Ptr || | ||||
| 		reflectValue.Kind() == reflect.Interface { | ||||
| 		reflectValue = reflectValue.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	switch reflectValue.Kind() { | ||||
| 	case reflect.Struct: | ||||
| 		results = [][]interface{}{make([]interface{}, len(fields))} | ||||
|  | ||||
| @ -278,8 +278,6 @@ func TestBelongsToAssociationUnscoped(t *testing.T) { | ||||
| 		t.Fatalf("failed to create items, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	tx = tx.Debug() | ||||
| 
 | ||||
| 	// test replace
 | ||||
| 	if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ | ||||
| 		Logo: "updated logo", | ||||
|  | ||||
| @ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	var count2 int64 | ||||
| 	if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { | ||||
| 	if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { | ||||
| 		t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||
| 	} | ||||
| 	if count2 != 2 { | ||||
|  | ||||
| @ -2,6 +2,7 @@ package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| @ -648,3 +649,180 @@ func TestCreateOnConflictWithDefaultJSON(t *testing.T) { | ||||
| 	AssertEqual(t, err, nil) | ||||
| 	AssertEqual(t, v2.Params, datatypes.JSONMap{"foo": "new-bar"}) | ||||
| } | ||||
| 
 | ||||
| func TestCreateFromMapWithoutPK(t *testing.T) { | ||||
| 	if !isMysql() { | ||||
| 		t.Skipf("This test case skipped, because of only supportting for mysql") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 1: one record, create from map[string]interface{}
 | ||||
| 	mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1} | ||||
| 	if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := mapValue1["id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	var result1 User | ||||
| 	if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var idVal int64 | ||||
| 	_, ok := mapValue1["id"].(uint) | ||||
| 	if ok { | ||||
| 		t.Skipf("This test case skipped, because the db supports returning") | ||||
| 	} | ||||
| 
 | ||||
| 	idVal, ok = mapValue1["id"].(int64) | ||||
| 	if !ok { | ||||
| 		t.Fatal("ret result missing id") | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(result1.ID) != idVal { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case2: one record, create from *map[string]interface{}
 | ||||
| 	mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1} | ||||
| 	if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := mapValue2["id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||
| 	} | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	_, ok = mapValue2["id"].(uint) | ||||
| 	if ok { | ||||
| 		t.Skipf("This test case skipped, because the db supports returning") | ||||
| 	} | ||||
| 
 | ||||
| 	idVal, ok = mapValue2["id"].(int64) | ||||
| 	if !ok { | ||||
| 		t.Fatal("ret result missing id") | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(result2.ID) != idVal { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 3: records
 | ||||
| 	values := []map[string]interface{}{ | ||||
| 		{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1}, | ||||
| 	} | ||||
| 
 | ||||
| 	beforeLen := len(values) | ||||
| 	if err := DB.Model(&User{}).Create(&values).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// mariadb with returning, values will be appended with id map
 | ||||
| 	if len(values) == beforeLen*2 { | ||||
| 		t.Skipf("This test case skipped, because the db supports returning") | ||||
| 	} | ||||
| 
 | ||||
| 	for i := range values { | ||||
| 		v, ok := values[i]["id"] | ||||
| 		if !ok { | ||||
| 			t.Fatal("failed to create data from map with table, returning map has no primary key") | ||||
| 		} | ||||
| 
 | ||||
| 		var result User | ||||
| 		if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 { | ||||
| 			t.Fatalf("failed to create from map, got error %v", err) | ||||
| 		} | ||||
| 		if int64(result.ID) != v.(int64) { | ||||
| 			t.Fatal("failed to create data from map with table, @id != id") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 	if !isMysql() { | ||||
| 		t.Skipf("This test case skipped, because of only supportting for mysql") | ||||
| 	} | ||||
| 	tableDB := DB.Table("`users`") | ||||
| 
 | ||||
| 	// case 1: create from map[string]interface{}
 | ||||
| 	record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} | ||||
| 	if err := tableDB.Create(record).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map with table, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := record["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	var res map[string]interface{} | ||||
| 	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res["id"].(uint64)) != record["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 2: create from *map[string]interface{}
 | ||||
| 	record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18} | ||||
| 	tableDB2 := DB.Table("users") | ||||
| 	if err := tableDB2.Create(&record1).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 	if _, ok := record1["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	var res1 map[string]interface{} | ||||
| 	if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res1["id"].(uint64)) != record1["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	// case 3: create from []map[string]interface{}
 | ||||
| 	records := []map[string]interface{}{ | ||||
| 		{"name": "create_from_map_with_table_2", "age": 19}, | ||||
| 		{"name": "create_from_map_with_table_3", "age": 20}, | ||||
| 	} | ||||
| 
 | ||||
| 	tableDB = DB.Table("users") | ||||
| 	if err := tableDB.Create(&records).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from slice of map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := records[0]["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := records[1]["@id"]; !ok { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	var res2 map[string]interface{} | ||||
| 	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) { | ||||
| 		t.Fatalf("failed to query data after create from slice of map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var res3 map[string]interface{} | ||||
| 	if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) { | ||||
| 		t.Fatalf("failed to query data after create from slice of map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res2["id"].(uint64)) != records[0]["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res3["id"].(uint64)) != records[1]["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // only sqlite, postgres support returning
 | ||||
| // only sqlite, postgres, sqlserver support returning
 | ||||
| func TestSoftDeleteReturning(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| @ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestDeleteReturning(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -236,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestEmbeddedRelations(t *testing.T) { | ||||
| 	type EmbUser struct { | ||||
| 		gorm.Model | ||||
| 		Name      string | ||||
| 		Age       uint | ||||
| 		Languages []Language `gorm:"many2many:EmbUserSpeak;"` | ||||
| 	} | ||||
| 
 | ||||
| 	type AdvancedUser struct { | ||||
| 		User     `gorm:"embedded"` | ||||
| 		EmbUser  `gorm:"embedded"` | ||||
| 		Advanced bool | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										30
									
								
								tests/go.mod
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								tests/go.mod
									
									
									
									
									
								
							| @ -1,16 +1,32 @@ | ||||
| module gorm.io/gorm/tests | ||||
| 
 | ||||
| go 1.16 | ||||
| go 1.18 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/google/uuid v1.4.0 | ||||
| 	github.com/jinzhu/now v1.1.5 | ||||
| 	github.com/lib/pq v1.10.9 | ||||
| 	gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 | ||||
| 	gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 | ||||
| 	gorm.io/driver/sqlite v1.5.2 | ||||
| 	gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a | ||||
| 	gorm.io/gorm v1.25.2-0.20230610234218-206613868439 | ||||
| 	gorm.io/driver/mysql v1.5.2 | ||||
| 	gorm.io/driver/postgres v1.5.4 | ||||
| 	gorm.io/driver/sqlite v1.5.4 | ||||
| 	gorm.io/driver/sqlserver v1.5.2 | ||||
| 	gorm.io/gorm v1.25.5 | ||||
| ) | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/go-sql-driver/mysql v1.7.1 // indirect | ||||
| 	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect | ||||
| 	github.com/golang-sql/sqlexp v0.1.0 // indirect | ||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||
| 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | ||||
| 	github.com/jackc/pgx/v5 v5.5.0 // indirect | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v1.14.18 // indirect | ||||
| 	github.com/microsoft/go-mssqldb v1.6.0 // indirect | ||||
| 	golang.org/x/crypto v0.15.0 // indirect | ||||
| 	golang.org/x/text v0.14.0 // indirect | ||||
| ) | ||||
| 
 | ||||
| replace gorm.io/gorm => ../ | ||||
| 
 | ||||
| replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 | ||||
|  | ||||
| @ -265,6 +265,10 @@ func isTiDB() bool { | ||||
| 	return os.Getenv("GORM_DIALECT") == "tidb" | ||||
| } | ||||
| 
 | ||||
| func isMysql() bool { | ||||
| 	return os.Getenv("GORM_DIALECT") == "mysql" | ||||
| } | ||||
| 
 | ||||
| func db(unscoped bool) *gorm.DB { | ||||
| 	if unscoped { | ||||
| 		return DB.Unscoped() | ||||
|  | ||||
| @ -862,6 +862,48 @@ func TestMigrateWithSpecialName(t *testing.T) { | ||||
| 	AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) | ||||
| } | ||||
| 
 | ||||
| // https://github.com/go-gorm/gorm/issues/4760
 | ||||
| func TestMigrateAutoIncrement(t *testing.T) { | ||||
| 	type AutoIncrementStruct struct { | ||||
| 		ID     int64   `gorm:"primarykey;autoIncrement"` | ||||
| 		Field1 uint32  `gorm:"column:field1"` | ||||
| 		Field2 float32 `gorm:"column:field2"` | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil { | ||||
| 		t.Fatalf("AutoMigrate err: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	const ROWS = 10 | ||||
| 	for idx := 0; idx < ROWS; idx++ { | ||||
| 		if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil { | ||||
| 			t.Fatalf("create auto_increment_struct fail, err: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	rows := make([]*AutoIncrementStruct, 0, ROWS) | ||||
| 	if err := DB.Order("id ASC").Find(&rows).Error; err != nil { | ||||
| 		t.Fatalf("find auto_increment_struct fail, err: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	ids := make([]int64, 0, len(rows)) | ||||
| 	for _, row := range rows { | ||||
| 		ids = append(ids, row.ID) | ||||
| 	} | ||||
| 	lastID := ids[len(ids)-1] | ||||
| 
 | ||||
| 	if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil { | ||||
| 		t.Fatalf("delete auto_increment_struct fail, err: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	newRow := &AutoIncrementStruct{} | ||||
| 	if err := DB.Create(newRow).Error; err != nil { | ||||
| 		t.Fatalf("create auto_increment_struct fail, err: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	AssertEqual(t, newRow.ID, lastID+1) | ||||
| } | ||||
| 
 | ||||
| // https://github.com/go-gorm/gorm/issues/5320
 | ||||
| func TestPrimarykeyID(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "postgres" { | ||||
| @ -1598,3 +1640,48 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestTableType(t *testing.T) { | ||||
| 	// currently it is only supported for mysql driver
 | ||||
| 	if !isMysql() { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	const tblName = "cities" | ||||
| 	const tblSchema = "gorm" | ||||
| 	const tblType = "BASE TABLE" | ||||
| 	const tblComment = "foobar comment" | ||||
| 
 | ||||
| 	type City struct { | ||||
| 		gorm.Model | ||||
| 		Name string `gorm:"unique"` | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Migrator().DropTable(&City{}) | ||||
| 
 | ||||
| 	if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { | ||||
| 		t.Fatalf("failed to migrate cities tables, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	tableType, err := DB.Table("cities").Migrator().TableType(&City{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to get table type, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if tableType.Schema() != tblSchema { | ||||
| 		t.Fatalf("expected tblSchema to be %s but got %s", tblSchema, tableType.Schema()) | ||||
| 	} | ||||
| 
 | ||||
| 	if tableType.Name() != tblName { | ||||
| 		t.Fatalf("expected table name to be %s but got %s", tblName, tableType.Name()) | ||||
| 	} | ||||
| 
 | ||||
| 	if tableType.Type() != tblType { | ||||
| 		t.Fatalf("expected table type to be %s but got %s", tblType, tableType.Type()) | ||||
| 	} | ||||
| 
 | ||||
| 	comment, ok := tableType.Comment() | ||||
| 	if !ok || comment != tblComment { | ||||
| 		t.Fatalf("expected comment %s got %s", tblComment, comment) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -429,7 +429,6 @@ func TestEmbedPreload(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	DB = DB.Debug() | ||||
| 	for _, test := range tests { | ||||
| 		t.Run(test.name, func(t *testing.T) { | ||||
| 			actual := Org{} | ||||
|  | ||||
| @ -2,6 +2,7 @@ package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| @ -658,6 +659,18 @@ func TestOrWithAllFields(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type Int64 int64 | ||||
| 
 | ||||
| func (v Int64) Value() (driver.Value, error) { | ||||
| 	return v - 1, nil | ||||
| } | ||||
| 
 | ||||
| func (f *Int64) Scan(v interface{}) error { | ||||
| 	y := v.(int64) | ||||
| 	*f = Int64(y + 1) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func TestPluck(t *testing.T) { | ||||
| 	users := []*User{ | ||||
| 		GetUser("pluck-user1", Config{}), | ||||
| @ -685,6 +698,11 @@ func TestPluck(t *testing.T) { | ||||
| 		t.Errorf("got error when pluck id: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var ids2 []Int64 | ||||
| 	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil { | ||||
| 		t.Errorf("got error when pluck id: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	for idx, name := range names { | ||||
| 		if name != users[idx].Name { | ||||
| 			t.Errorf("Unexpected result on pluck name, got %+v", names) | ||||
| @ -697,6 +715,12 @@ func TestPluck(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for idx, id := range ids2 { | ||||
| 		if int(id) != int(users[idx].ID+1) { | ||||
| 			t.Errorf("Unexpected result on pluck id, got %+v", ids) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	var times []time.Time | ||||
| 	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { | ||||
| 		t.Errorf("got error when pluck time: %v", err) | ||||
|  | ||||
| @ -43,9 +43,6 @@ func init() { | ||||
| 		} | ||||
| 
 | ||||
| 		RunMigrations() | ||||
| 		if DB.Dialector.Name() == "sqlite" { | ||||
| 			DB.Exec("PRAGMA foreign_keys = ON") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -89,7 +86,10 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { | ||||
| 		db, err = gorm.Open(mysql.Open(dbDSN), cfg) | ||||
| 	default: | ||||
| 		log.Println("testing sqlite3...") | ||||
| 		db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg) | ||||
| 		db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) | ||||
| 		if err == nil { | ||||
| 			db.Exec("PRAGMA foreign_keys = ON") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -765,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // only sqlite, postgres support returning
 | ||||
| // only sqlite, postgres, sqlserver support returning
 | ||||
| func TestUpdateReturning(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| @ -838,7 +838,7 @@ func TestSaveWithHooks(t *testing.T) { | ||||
| 	saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { | ||||
| 		var newOwner TokenOwner | ||||
| 		if err := DB.Transaction(func(tx *gorm.DB) error { | ||||
| 			if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { | ||||
| 			if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { | ||||
|  | ||||
| @ -89,19 +89,28 @@ func Contains(elems []string, elem string) bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func AssertEqual(src, dst interface{}) bool { | ||||
| 	if !reflect.DeepEqual(src, dst) { | ||||
| 		if valuer, ok := src.(driver.Valuer); ok { | ||||
| 			src, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := dst.(driver.Valuer); ok { | ||||
| 			dst, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		return reflect.DeepEqual(src, dst) | ||||
| func AssertEqual(x, y interface{}) bool { | ||||
| 	if reflect.DeepEqual(x, y) { | ||||
| 		return true | ||||
| 	} | ||||
| 	return true | ||||
| 	if x == nil || y == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	xval := reflect.ValueOf(x) | ||||
| 	yval := reflect.ValueOf(y) | ||||
| 	if xval.Kind() == reflect.Ptr && xval.IsNil() || | ||||
| 		yval.Kind() == reflect.Ptr && yval.IsNil() { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if valuer, ok := x.(driver.Valuer); ok { | ||||
| 		x, _ = valuer.Value() | ||||
| 	} | ||||
| 	if valuer, ok := y.(driver.Valuer); ok { | ||||
| 		y, _ = valuer.Value() | ||||
| 	} | ||||
| 	return reflect.DeepEqual(x, y) | ||||
| } | ||||
| 
 | ||||
| func ToString(value interface{}) string { | ||||
|  | ||||
| @ -98,6 +98,7 @@ func TestAssertEqual(t *testing.T) { | ||||
| 		{"error not equal", errors.New("1"), errors.New("2"), false}, | ||||
| 		{"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, | ||||
| 		{"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, | ||||
| 		{"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false}, | ||||
| 	} | ||||
| 	for _, test := range assertEqualTests { | ||||
| 		t.Run(test.name, func(t *testing.T) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 black
						black