Refactor tests files
This commit is contained in:
		
							parent
							
								
									5790ba9ef4
								
							
						
					
					
						commit
						8bb05a5a69
					
				| @ -7,7 +7,7 @@ import ( | |||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| 	"gorm.io/gorm/tests" | 	"gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BenchmarkSelect(b *testing.B) { | func BenchmarkSelect(b *testing.B) { | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ import ( | |||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| 	"gorm.io/gorm/tests" | 	"gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var db, _ = gorm.Open(tests.DummyDialector{}, nil) | var db, _ = gorm.Open(tests.DummyDialector{}, nil) | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ import ( | |||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| 	"gorm.io/gorm/tests" | 	"gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestExpr(t *testing.T) { | func TestExpr(t *testing.T) { | ||||||
|  | |||||||
| @ -1,225 +0,0 @@ | |||||||
| package mssql |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"reflect" |  | ||||||
| 	"sort" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/callbacks" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/schema" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func Create(db *gorm.DB) { |  | ||||||
| 	if db.Statement.Schema != nil && !db.Statement.Unscoped { |  | ||||||
| 		for _, c := range db.Statement.Schema.CreateClauses { |  | ||||||
| 			db.Statement.AddClause(c) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if db.Statement.SQL.String() == "" { |  | ||||||
| 		setIdentityInsert := false |  | ||||||
| 		c := db.Statement.Clauses["ON CONFLICT"] |  | ||||||
| 		onConflict, hasConflict := c.Expression.(clause.OnConflict) |  | ||||||
| 
 |  | ||||||
| 		if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { |  | ||||||
| 			setIdentityInsert = false |  | ||||||
| 			switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 			case reflect.Struct: |  | ||||||
| 				_, isZero := field.ValueOf(db.Statement.ReflectValue) |  | ||||||
| 				setIdentityInsert = !isZero |  | ||||||
| 			case reflect.Slice: |  | ||||||
| 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { |  | ||||||
| 					_, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i)) |  | ||||||
| 					setIdentityInsert = !isZero |  | ||||||
| 					break |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) { |  | ||||||
| 				setIdentityInsert = true |  | ||||||
| 				db.Statement.WriteString("SET IDENTITY_INSERT ") |  | ||||||
| 				db.Statement.WriteQuoted(db.Statement.Table) |  | ||||||
| 				db.Statement.WriteString(" ON;") |  | ||||||
| 			} else { |  | ||||||
| 				setIdentityInsert = false |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 { |  | ||||||
| 			MergeCreate(db, onConflict) |  | ||||||
| 		} else { |  | ||||||
| 			db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) |  | ||||||
| 			db.Statement.Build("INSERT") |  | ||||||
| 			db.Statement.WriteByte(' ') |  | ||||||
| 
 |  | ||||||
| 			db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) |  | ||||||
| 			if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok { |  | ||||||
| 				if len(values.Columns) > 0 { |  | ||||||
| 					db.Statement.WriteByte('(') |  | ||||||
| 					for idx, column := range values.Columns { |  | ||||||
| 						if idx > 0 { |  | ||||||
| 							db.Statement.WriteByte(',') |  | ||||||
| 						} |  | ||||||
| 						db.Statement.WriteQuoted(column) |  | ||||||
| 					} |  | ||||||
| 					db.Statement.WriteByte(')') |  | ||||||
| 
 |  | ||||||
| 					outputInserted(db) |  | ||||||
| 
 |  | ||||||
| 					db.Statement.WriteString(" VALUES ") |  | ||||||
| 
 |  | ||||||
| 					for idx, value := range values.Values { |  | ||||||
| 						if idx > 0 { |  | ||||||
| 							db.Statement.WriteByte(',') |  | ||||||
| 						} |  | ||||||
| 
 |  | ||||||
| 						db.Statement.WriteByte('(') |  | ||||||
| 						db.Statement.AddVar(db.Statement, value...) |  | ||||||
| 						db.Statement.WriteByte(')') |  | ||||||
| 					} |  | ||||||
| 
 |  | ||||||
| 					db.Statement.WriteString(";") |  | ||||||
| 				} else { |  | ||||||
| 					db.Statement.WriteString("DEFAULT VALUES;") |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if setIdentityInsert { |  | ||||||
| 			db.Statement.WriteString("SET IDENTITY_INSERT ") |  | ||||||
| 			db.Statement.WriteQuoted(db.Statement.Table) |  | ||||||
| 			db.Statement.WriteString(" OFF;") |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if !db.DryRun { |  | ||||||
| 		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |  | ||||||
| 
 |  | ||||||
| 		if err == nil { |  | ||||||
| 			defer rows.Close() |  | ||||||
| 
 |  | ||||||
| 			if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { |  | ||||||
| 				sortedKeys := []string{} |  | ||||||
| 				for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { |  | ||||||
| 					sortedKeys = append(sortedKeys, field.DBName) |  | ||||||
| 				} |  | ||||||
| 				sort.Strings(sortedKeys) |  | ||||||
| 
 |  | ||||||
| 				returnningFields := make([]*schema.Field, len(sortedKeys)) |  | ||||||
| 				for idx, key := range sortedKeys { |  | ||||||
| 					returnningFields[idx] = db.Statement.Schema.LookUpField(key) |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				values := make([]interface{}, len(returnningFields)) |  | ||||||
| 
 |  | ||||||
| 				switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 				case reflect.Slice, reflect.Array: |  | ||||||
| 					for rows.Next() { |  | ||||||
| 						for idx, field := range returnningFields { |  | ||||||
| 							values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() |  | ||||||
| 						} |  | ||||||
| 
 |  | ||||||
| 						db.RowsAffected++ |  | ||||||
| 						db.AddError(rows.Scan(values...)) |  | ||||||
| 					} |  | ||||||
| 				case reflect.Struct: |  | ||||||
| 					for idx, field := range returnningFields { |  | ||||||
| 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() |  | ||||||
| 					} |  | ||||||
| 
 |  | ||||||
| 					if rows.Next() { |  | ||||||
| 						db.RowsAffected++ |  | ||||||
| 						db.AddError(rows.Scan(values...)) |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			db.AddError(err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { |  | ||||||
| 	values := callbacks.ConvertToCreateValues(db.Statement) |  | ||||||
| 
 |  | ||||||
| 	db.Statement.WriteString("MERGE INTO ") |  | ||||||
| 	db.Statement.WriteQuoted(db.Statement.Table) |  | ||||||
| 	db.Statement.WriteString(" USING (VALUES") |  | ||||||
| 	for idx, value := range values.Values { |  | ||||||
| 		if idx > 0 { |  | ||||||
| 			db.Statement.WriteByte(',') |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		db.Statement.WriteByte('(') |  | ||||||
| 		db.Statement.AddVar(db.Statement, value...) |  | ||||||
| 		db.Statement.WriteByte(')') |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	db.Statement.WriteString(") AS source (") |  | ||||||
| 	for idx, column := range values.Columns { |  | ||||||
| 		if idx > 0 { |  | ||||||
| 			db.Statement.WriteByte(',') |  | ||||||
| 		} |  | ||||||
| 		db.Statement.WriteQuoted(column.Name) |  | ||||||
| 	} |  | ||||||
| 	db.Statement.WriteString(") ON ") |  | ||||||
| 
 |  | ||||||
| 	var where clause.Where |  | ||||||
| 	for _, field := range db.Statement.Schema.PrimaryFields { |  | ||||||
| 		where.Exprs = append(where.Exprs, clause.Eq{ |  | ||||||
| 			Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, |  | ||||||
| 			Value:  clause.Column{Table: "source", Name: field.DBName}, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	where.Build(db.Statement) |  | ||||||
| 
 |  | ||||||
| 	if len(onConflict.DoUpdates) > 0 { |  | ||||||
| 		db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") |  | ||||||
| 		onConflict.DoUpdates.Build(db.Statement) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") |  | ||||||
| 
 |  | ||||||
| 	for idx, column := range values.Columns { |  | ||||||
| 		if idx > 0 { |  | ||||||
| 			db.Statement.WriteByte(',') |  | ||||||
| 		} |  | ||||||
| 		db.Statement.WriteQuoted(column.Name) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	db.Statement.WriteString(") VALUES (") |  | ||||||
| 
 |  | ||||||
| 	for idx, column := range values.Columns { |  | ||||||
| 		if idx > 0 { |  | ||||||
| 			db.Statement.WriteByte(',') |  | ||||||
| 		} |  | ||||||
| 		db.Statement.WriteQuoted(clause.Column{ |  | ||||||
| 			Table: "source", |  | ||||||
| 			Name:  column.Name, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	db.Statement.WriteString(")") |  | ||||||
| 	outputInserted(db) |  | ||||||
| 	db.Statement.WriteString(";") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func outputInserted(db *gorm.DB) { |  | ||||||
| 	if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { |  | ||||||
| 		sortedKeys := []string{} |  | ||||||
| 		for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { |  | ||||||
| 			sortedKeys = append(sortedKeys, field.DBName) |  | ||||||
| 		} |  | ||||||
| 		sort.Strings(sortedKeys) |  | ||||||
| 
 |  | ||||||
| 		db.Statement.WriteString(" OUTPUT") |  | ||||||
| 		for idx, key := range sortedKeys { |  | ||||||
| 			if idx > 0 { |  | ||||||
| 				db.Statement.WriteString(",") |  | ||||||
| 			} |  | ||||||
| 			db.Statement.WriteString(" INSERTED.") |  | ||||||
| 			db.Statement.AddVar(db.Statement, clause.Column{Name: key}) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @ -1,142 +0,0 @@ | |||||||
| package mssql |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Migrator struct { |  | ||||||
| 	migrator.Migrator |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasTable(value interface{}) bool { |  | ||||||
| 	var count int |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		return m.DB.Raw( |  | ||||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", |  | ||||||
| 			stmt.Table, m.CurrentDatabase(), |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) RenameTable(oldName, newName interface{}) error { |  | ||||||
| 	var oldTable, newTable string |  | ||||||
| 	if v, ok := oldName.(string); ok { |  | ||||||
| 		oldTable = v |  | ||||||
| 	} else { |  | ||||||
| 		stmt := &gorm.Statement{DB: m.DB} |  | ||||||
| 		if err := stmt.Parse(oldName); err == nil { |  | ||||||
| 			oldTable = stmt.Table |  | ||||||
| 		} else { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if v, ok := newName.(string); ok { |  | ||||||
| 		newTable = v |  | ||||||
| 	} else { |  | ||||||
| 		stmt := &gorm.Statement{DB: m.DB} |  | ||||||
| 		if err := stmt.Parse(newName); err == nil { |  | ||||||
| 			newTable = stmt.Table |  | ||||||
| 		} else { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return m.DB.Exec( |  | ||||||
| 		"sp_rename @objname = ?, @newname = ?;", |  | ||||||
| 		clause.Table{Name: oldTable}, clause.Table{Name: newTable}, |  | ||||||
| 	).Error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasColumn(value interface{}, field string) bool { |  | ||||||
| 	var count int64 |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		currentDatabase := m.DB.Migrator().CurrentDatabase() |  | ||||||
| 		name := field |  | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { |  | ||||||
| 			name = field.DBName |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Raw( |  | ||||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", |  | ||||||
| 			currentDatabase, stmt.Table, name, |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) AlterColumn(value interface{}, field string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { |  | ||||||
| 			return m.DB.Exec( |  | ||||||
| 				"ALTER TABLE ? ALTER COLUMN ? ?", |  | ||||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), |  | ||||||
| 			).Error |  | ||||||
| 		} |  | ||||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if field := stmt.Schema.LookUpField(oldName); field != nil { |  | ||||||
| 			oldName = field.DBName |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if field := stmt.Schema.LookUpField(newName); field != nil { |  | ||||||
| 			newName = field.DBName |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Exec( |  | ||||||
| 			"sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';", |  | ||||||
| 			fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, |  | ||||||
| 		).Error |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { |  | ||||||
| 	var count int |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { |  | ||||||
| 			name = idx.Name |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Raw( |  | ||||||
| 			"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", |  | ||||||
| 			name, stmt.Table, |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Exec( |  | ||||||
| 			"sp_rename @objname = ?, @newname = ?, @objtype = 'INDEX';", |  | ||||||
| 			fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, |  | ||||||
| 		).Error |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasConstraint(value interface{}, name string) bool { |  | ||||||
| 	var count int64 |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		return m.DB.Raw( |  | ||||||
| 			`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ?  AND T.Name = ? AND I.TABLE_CATALOG = ?;`, |  | ||||||
| 			name, stmt.Table, m.CurrentDatabase(), |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) CurrentDatabase() (name string) { |  | ||||||
| 	m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| @ -1,127 +0,0 @@ | |||||||
| package mssql |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"database/sql" |  | ||||||
| 	"fmt" |  | ||||||
| 	"regexp" |  | ||||||
| 	"strconv" |  | ||||||
| 
 |  | ||||||
| 	_ "github.com/denisenkom/go-mssqldb" |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/callbacks" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/logger" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| 	"gorm.io/gorm/schema" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Dialector struct { |  | ||||||
| 	DSN string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Name() string { |  | ||||||
| 	return "mssql" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Open(dsn string) gorm.Dialector { |  | ||||||
| 	return &Dialector{DSN: dsn} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { |  | ||||||
| 	// register callbacks
 |  | ||||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) |  | ||||||
| 	db.Callback().Create().Replace("gorm:create", Create) |  | ||||||
| 	db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) |  | ||||||
| 
 |  | ||||||
| 	for k, v := range dialector.ClauseBuilders() { |  | ||||||
| 		db.ClauseBuilders[k] = v |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { |  | ||||||
| 	return map[string]clause.ClauseBuilder{ |  | ||||||
| 		"LIMIT": func(c clause.Clause, builder clause.Builder) { |  | ||||||
| 			if limit, ok := c.Expression.(clause.Limit); ok { |  | ||||||
| 				if limit.Offset > 0 { |  | ||||||
| 					builder.WriteString("OFFSET ") |  | ||||||
| 					builder.WriteString(strconv.Itoa(limit.Offset)) |  | ||||||
| 					builder.WriteString("ROWS") |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				if limit.Limit > 0 { |  | ||||||
| 					if limit.Offset == 0 { |  | ||||||
| 						builder.WriteString(" OFFSET 0 ROWS") |  | ||||||
| 					} |  | ||||||
| 					builder.WriteString(" FETCH NEXT ") |  | ||||||
| 					builder.WriteString(strconv.Itoa(limit.Limit)) |  | ||||||
| 					builder.WriteString(" ROWS ONLY") |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { |  | ||||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ |  | ||||||
| 		DB:                          db, |  | ||||||
| 		Dialector:                   dialector, |  | ||||||
| 		CreateIndexAfterCreateTable: true, |  | ||||||
| 	}}} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { |  | ||||||
| 	writer.WriteString("@p") |  | ||||||
| 	writer.WriteString(strconv.Itoa(len(stmt.Vars))) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { |  | ||||||
| 	writer.WriteByte('"') |  | ||||||
| 	writer.WriteString(str) |  | ||||||
| 	writer.WriteByte('"') |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var numericPlaceholder = regexp.MustCompile("@p(\\d+)") |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { |  | ||||||
| 	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { |  | ||||||
| 	switch field.DataType { |  | ||||||
| 	case schema.Bool: |  | ||||||
| 		return "bit" |  | ||||||
| 	case schema.Int, schema.Uint: |  | ||||||
| 		var sqlType string |  | ||||||
| 		switch { |  | ||||||
| 		case field.Size < 16: |  | ||||||
| 			sqlType = "smallint" |  | ||||||
| 		case field.Size < 31: |  | ||||||
| 			sqlType = "int" |  | ||||||
| 		default: |  | ||||||
| 			sqlType = "bigint" |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { |  | ||||||
| 			return sqlType + " IDENTITY(1,1)" |  | ||||||
| 		} |  | ||||||
| 		return sqlType |  | ||||||
| 	case schema.Float: |  | ||||||
| 		return "float" |  | ||||||
| 	case schema.String: |  | ||||||
| 		size := field.Size |  | ||||||
| 		if field.PrimaryKey && size == 0 { |  | ||||||
| 			size = 256 |  | ||||||
| 		} |  | ||||||
| 		if size > 0 && size <= 4000 { |  | ||||||
| 			return fmt.Sprintf("nvarchar(%d)", size) |  | ||||||
| 		} |  | ||||||
| 		return "nvarchar(MAX)" |  | ||||||
| 	case schema.Time: |  | ||||||
| 		return "datetimeoffset" |  | ||||||
| 	case schema.Bytes: |  | ||||||
| 		return "varbinary(MAX)" |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
| @ -1,58 +0,0 @@ | |||||||
| package mysql |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Migrator struct { |  | ||||||
| 	migrator.Migrator |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) AlterColumn(value interface{}, field string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { |  | ||||||
| 			return m.DB.Exec( |  | ||||||
| 				"ALTER TABLE ? MODIFY COLUMN ? ?", |  | ||||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), |  | ||||||
| 			).Error |  | ||||||
| 		} |  | ||||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) DropTable(values ...interface{}) error { |  | ||||||
| 	values = m.ReorderModels(values, false) |  | ||||||
| 	tx := m.DB.Session(&gorm.Session{}) |  | ||||||
| 	tx.Exec("SET FOREIGN_KEY_CHECKS = 0;") |  | ||||||
| 	for i := len(values) - 1; i >= 0; i-- { |  | ||||||
| 		if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { |  | ||||||
| 			return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error |  | ||||||
| 		}); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	tx.Exec("SET FOREIGN_KEY_CHECKS = 1;") |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) DropConstraint(value interface{}, name string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		for _, chk := range stmt.Schema.ParseCheckConstraints() { |  | ||||||
| 			if chk.Name == name { |  | ||||||
| 				return m.DB.Exec( |  | ||||||
| 					"ALTER TABLE ? DROP CHECK ?", |  | ||||||
| 					clause.Table{Name: stmt.Table}, clause.Column{Name: name}, |  | ||||||
| 				).Error |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Exec( |  | ||||||
| 			"ALTER TABLE ? DROP FOREIGN KEY ?", |  | ||||||
| 			clause.Table{Name: stmt.Table}, clause.Column{Name: name}, |  | ||||||
| 		).Error |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| @ -1,169 +0,0 @@ | |||||||
| package mysql |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"database/sql" |  | ||||||
| 	"fmt" |  | ||||||
| 	"math" |  | ||||||
| 
 |  | ||||||
| 	_ "github.com/go-sql-driver/mysql" |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/callbacks" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/logger" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| 	"gorm.io/gorm/schema" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Dialector struct { |  | ||||||
| 	DSN string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Open(dsn string) gorm.Dialector { |  | ||||||
| 	return &Dialector{DSN: dsn} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Name() string { |  | ||||||
| 	return "mysql" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { |  | ||||||
| 	// register callbacks
 |  | ||||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) |  | ||||||
| 	db.ConnPool, err = sql.Open("mysql", dialector.DSN) |  | ||||||
| 
 |  | ||||||
| 	for k, v := range dialector.ClauseBuilders() { |  | ||||||
| 		db.ClauseBuilders[k] = v |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { |  | ||||||
| 	return map[string]clause.ClauseBuilder{ |  | ||||||
| 		"ON CONFLICT": func(c clause.Clause, builder clause.Builder) { |  | ||||||
| 			if onConflict, ok := c.Expression.(clause.OnConflict); ok { |  | ||||||
| 				builder.WriteString("ON DUPLICATE KEY UPDATE ") |  | ||||||
| 				if len(onConflict.DoUpdates) == 0 { |  | ||||||
| 					if s := builder.(*gorm.Statement).Schema; s != nil { |  | ||||||
| 						var column clause.Column |  | ||||||
| 						onConflict.DoNothing = false |  | ||||||
| 
 |  | ||||||
| 						if s.PrioritizedPrimaryField != nil { |  | ||||||
| 							column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} |  | ||||||
| 						} else { |  | ||||||
| 							for _, field := range s.FieldsByDBName { |  | ||||||
| 								column = clause.Column{Name: field.DBName} |  | ||||||
| 								break |  | ||||||
| 							} |  | ||||||
| 						} |  | ||||||
| 						onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				onConflict.DoUpdates.Build(builder) |  | ||||||
| 			} else { |  | ||||||
| 				c.Build(builder) |  | ||||||
| 			} |  | ||||||
| 		}, |  | ||||||
| 		"VALUES": func(c clause.Clause, builder clause.Builder) { |  | ||||||
| 			if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { |  | ||||||
| 				builder.WriteString("VALUES()") |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 			c.Build(builder) |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { |  | ||||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ |  | ||||||
| 		DB:        db, |  | ||||||
| 		Dialector: dialector, |  | ||||||
| 	}}} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { |  | ||||||
| 	writer.WriteByte('?') |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { |  | ||||||
| 	writer.WriteByte('`') |  | ||||||
| 	writer.WriteString(str) |  | ||||||
| 	writer.WriteByte('`') |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { |  | ||||||
| 	return logger.ExplainSQL(sql, nil, `"`, vars...) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { |  | ||||||
| 	switch field.DataType { |  | ||||||
| 	case schema.Bool: |  | ||||||
| 		return "boolean" |  | ||||||
| 	case schema.Int, schema.Uint: |  | ||||||
| 		sqlType := "int" |  | ||||||
| 		switch { |  | ||||||
| 		case field.Size <= 8: |  | ||||||
| 			sqlType = "tinyint" |  | ||||||
| 		case field.Size <= 16: |  | ||||||
| 			sqlType = "smallint" |  | ||||||
| 		case field.Size <= 32: |  | ||||||
| 			sqlType = "int" |  | ||||||
| 		default: |  | ||||||
| 			sqlType = "bigint" |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if field.DataType == schema.Uint { |  | ||||||
| 			sqlType += " unsigned" |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { |  | ||||||
| 			sqlType += " AUTO_INCREMENT" |  | ||||||
| 		} |  | ||||||
| 		return sqlType |  | ||||||
| 	case schema.Float: |  | ||||||
| 		if field.Size <= 32 { |  | ||||||
| 			return "float" |  | ||||||
| 		} |  | ||||||
| 		return "double" |  | ||||||
| 	case schema.String: |  | ||||||
| 		size := field.Size |  | ||||||
| 		if size == 0 { |  | ||||||
| 			if field.PrimaryKey || field.HasDefaultValue { |  | ||||||
| 				size = 256 |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if size >= 65536 && size <= int(math.Pow(2, 24)) { |  | ||||||
| 			return "mediumtext" |  | ||||||
| 		} else if size > int(math.Pow(2, 24)) || size <= 0 { |  | ||||||
| 			return "longtext" |  | ||||||
| 		} |  | ||||||
| 		return fmt.Sprintf("varchar(%d)", size) |  | ||||||
| 	case schema.Time: |  | ||||||
| 		precision := "" |  | ||||||
| 		if field.Precision == 0 { |  | ||||||
| 			field.Precision = 3 |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if field.Precision > 0 { |  | ||||||
| 			precision = fmt.Sprintf("(%d)", field.Precision) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if field.NotNull || field.PrimaryKey { |  | ||||||
| 			return "datetime" + precision |  | ||||||
| 		} |  | ||||||
| 		return "datetime" + precision + " NULL" |  | ||||||
| 	case schema.Bytes: |  | ||||||
| 		if field.Size > 0 && field.Size < 65536 { |  | ||||||
| 			return fmt.Sprintf("varbinary(%d)", field.Size) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { |  | ||||||
| 			return "mediumblob" |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return "longblob" |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
| @ -1,139 +0,0 @@ | |||||||
| package postgres |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| 	"gorm.io/gorm/schema" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Migrator struct { |  | ||||||
| 	migrator.Migrator |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) CurrentDatabase() (name string) { |  | ||||||
| 	m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { |  | ||||||
| 	for _, opt := range opts { |  | ||||||
| 		str := stmt.Quote(opt.DBName) |  | ||||||
| 		if opt.Expression != "" { |  | ||||||
| 			str = opt.Expression |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if opt.Collate != "" { |  | ||||||
| 			str += " COLLATE " + opt.Collate |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if opt.Sort != "" { |  | ||||||
| 			str += " " + opt.Sort |  | ||||||
| 		} |  | ||||||
| 		results = append(results, clause.Expr{SQL: str}) |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { |  | ||||||
| 	var count int64 |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { |  | ||||||
| 			name = idx.Name |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Raw( |  | ||||||
| 			"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) CreateIndex(value interface{}, name string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { |  | ||||||
| 			opts := m.BuildIndexOptions(idx.Fields, stmt) |  | ||||||
| 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} |  | ||||||
| 
 |  | ||||||
| 			createIndexSQL := "CREATE " |  | ||||||
| 			if idx.Class != "" { |  | ||||||
| 				createIndexSQL += idx.Class + " " |  | ||||||
| 			} |  | ||||||
| 			createIndexSQL += "INDEX ?" |  | ||||||
| 
 |  | ||||||
| 			if idx.Type != "" { |  | ||||||
| 				createIndexSQL += " USING " + idx.Type |  | ||||||
| 			} |  | ||||||
| 			createIndexSQL += " ON ??" |  | ||||||
| 
 |  | ||||||
| 			if idx.Where != "" { |  | ||||||
| 				createIndexSQL += " WHERE " + idx.Where |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			return m.DB.Exec(createIndexSQL, values...).Error |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return fmt.Errorf("failed to create index with name %v", name) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		return m.DB.Exec( |  | ||||||
| 			"ALTER INDEX ? RENAME TO ?", |  | ||||||
| 			clause.Column{Name: oldName}, clause.Column{Name: newName}, |  | ||||||
| 		).Error |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) DropIndex(value interface{}, name string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { |  | ||||||
| 			name = idx.Name |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasTable(value interface{}) bool { |  | ||||||
| 	var count int64 |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema =  CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) DropTable(values ...interface{}) error { |  | ||||||
| 	values = m.ReorderModels(values, false) |  | ||||||
| 	tx := m.DB.Session(&gorm.Session{}) |  | ||||||
| 	for i := len(values) - 1; i >= 0; i-- { |  | ||||||
| 		if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { |  | ||||||
| 			return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error |  | ||||||
| 		}); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasColumn(value interface{}, field string) bool { |  | ||||||
| 	var count int64 |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		name := field |  | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { |  | ||||||
| 			name = field.DBName |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Raw( |  | ||||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", |  | ||||||
| 			stmt.Table, name, |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| @ -1,102 +0,0 @@ | |||||||
| package postgres |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"database/sql" |  | ||||||
| 	"fmt" |  | ||||||
| 	"regexp" |  | ||||||
| 	"strconv" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/callbacks" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/logger" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| 	"gorm.io/gorm/schema" |  | ||||||
| 	_ "github.com/lib/pq" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Dialector struct { |  | ||||||
| 	DSN string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Open(dsn string) gorm.Dialector { |  | ||||||
| 	return &Dialector{DSN: dsn} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Name() string { |  | ||||||
| 	return "postgres" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { |  | ||||||
| 	// register callbacks
 |  | ||||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ |  | ||||||
| 		WithReturning: true, |  | ||||||
| 	}) |  | ||||||
| 	db.ConnPool, err = sql.Open("postgres", dialector.DSN) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { |  | ||||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ |  | ||||||
| 		DB:                          db, |  | ||||||
| 		Dialector:                   dialector, |  | ||||||
| 		CreateIndexAfterCreateTable: true, |  | ||||||
| 	}}} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { |  | ||||||
| 	writer.WriteByte('$') |  | ||||||
| 	writer.WriteString(strconv.Itoa(len(stmt.Vars))) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { |  | ||||||
| 	writer.WriteByte('"') |  | ||||||
| 	writer.WriteString(str) |  | ||||||
| 	writer.WriteByte('"') |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { |  | ||||||
| 	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { |  | ||||||
| 	switch field.DataType { |  | ||||||
| 	case schema.Bool: |  | ||||||
| 		return "boolean" |  | ||||||
| 	case schema.Int, schema.Uint: |  | ||||||
| 		if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { |  | ||||||
| 			switch { |  | ||||||
| 			case field.Size < 16: |  | ||||||
| 				return "smallserial" |  | ||||||
| 			case field.Size < 31: |  | ||||||
| 				return "serial" |  | ||||||
| 			default: |  | ||||||
| 				return "bigserial" |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			switch { |  | ||||||
| 			case field.Size < 16: |  | ||||||
| 				return "smallint" |  | ||||||
| 			case field.Size < 31: |  | ||||||
| 				return "integer" |  | ||||||
| 			default: |  | ||||||
| 				return "bigint" |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	case schema.Float: |  | ||||||
| 		return "decimal" |  | ||||||
| 	case schema.String: |  | ||||||
| 		if field.Size > 0 { |  | ||||||
| 			return fmt.Sprintf("varchar(%d)", field.Size) |  | ||||||
| 		} |  | ||||||
| 		return "text" |  | ||||||
| 	case schema.Time: |  | ||||||
| 		return "timestamptz" |  | ||||||
| 	case schema.Bytes: |  | ||||||
| 		return "bytea" |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
| @ -1,211 +0,0 @@ | |||||||
| package sqlite |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"regexp" |  | ||||||
| 	"strings" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| 	"gorm.io/gorm/schema" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Migrator struct { |  | ||||||
| 	migrator.Migrator |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasTable(value interface{}) bool { |  | ||||||
| 	var count int |  | ||||||
| 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasColumn(value interface{}, name string) bool { |  | ||||||
| 	var count int |  | ||||||
| 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { |  | ||||||
| 			name = field.DBName |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Raw( |  | ||||||
| 			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", |  | ||||||
| 			"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 	}) |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) AlterColumn(value interface{}, name string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { |  | ||||||
| 			var ( |  | ||||||
| 				createSQL    string |  | ||||||
| 				newTableName = stmt.Table + "__temp" |  | ||||||
| 			) |  | ||||||
| 
 |  | ||||||
| 			m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) |  | ||||||
| 
 |  | ||||||
| 			if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { |  | ||||||
| 				tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") |  | ||||||
| 				if err != nil { |  | ||||||
| 					return err |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) |  | ||||||
| 				createSQL = reg.ReplaceAllString(createSQL, "?") |  | ||||||
| 
 |  | ||||||
| 				var columns []string |  | ||||||
| 				columnTypes, _ := m.DB.Migrator().ColumnTypes(value) |  | ||||||
| 				for _, columnType := range columnTypes { |  | ||||||
| 					columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) |  | ||||||
| 				return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error |  | ||||||
| 			} else { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			return fmt.Errorf("failed to alter field with name %v", name) |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) DropColumn(value interface{}, name string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { |  | ||||||
| 			name = field.DBName |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		var ( |  | ||||||
| 			createSQL    string |  | ||||||
| 			newTableName = stmt.Table + "__temp" |  | ||||||
| 		) |  | ||||||
| 
 |  | ||||||
| 		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) |  | ||||||
| 
 |  | ||||||
| 		if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { |  | ||||||
| 			tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) |  | ||||||
| 			createSQL = reg.ReplaceAllString(createSQL, "") |  | ||||||
| 
 |  | ||||||
| 			var columns []string |  | ||||||
| 			columnTypes, _ := m.DB.Migrator().ColumnTypes(value) |  | ||||||
| 			for _, columnType := range columnTypes { |  | ||||||
| 				if columnType.Name() != name { |  | ||||||
| 					columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) |  | ||||||
| 
 |  | ||||||
| 			return m.DB.Exec(createSQL).Error |  | ||||||
| 		} else { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) CreateConstraint(interface{}, string) error { |  | ||||||
| 	return gorm.ErrNotImplemented |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) DropConstraint(interface{}, string) error { |  | ||||||
| 	return gorm.ErrNotImplemented |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) CurrentDatabase() (name string) { |  | ||||||
| 	var null interface{} |  | ||||||
| 	m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { |  | ||||||
| 	for _, opt := range opts { |  | ||||||
| 		str := stmt.Quote(opt.DBName) |  | ||||||
| 		if opt.Expression != "" { |  | ||||||
| 			str = opt.Expression |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if opt.Collate != "" { |  | ||||||
| 			str += " COLLATE " + opt.Collate |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if opt.Sort != "" { |  | ||||||
| 			str += " " + opt.Sort |  | ||||||
| 		} |  | ||||||
| 		results = append(results, clause.Expr{SQL: str}) |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) CreateIndex(value interface{}, name string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { |  | ||||||
| 			opts := m.BuildIndexOptions(idx.Fields, stmt) |  | ||||||
| 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} |  | ||||||
| 
 |  | ||||||
| 			createIndexSQL := "CREATE " |  | ||||||
| 			if idx.Class != "" { |  | ||||||
| 				createIndexSQL += idx.Class + " " |  | ||||||
| 			} |  | ||||||
| 			createIndexSQL += "INDEX ?" |  | ||||||
| 
 |  | ||||||
| 			if idx.Type != "" { |  | ||||||
| 				createIndexSQL += " USING " + idx.Type |  | ||||||
| 			} |  | ||||||
| 			createIndexSQL += " ON ??" |  | ||||||
| 
 |  | ||||||
| 			if idx.Where != "" { |  | ||||||
| 				createIndexSQL += " WHERE " + idx.Where |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			return m.DB.Exec(createIndexSQL, values...).Error |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return fmt.Errorf("failed to create index with name %v", name) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { |  | ||||||
| 	var count int |  | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { |  | ||||||
| 			name = idx.Name |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		m.DB.Raw( |  | ||||||
| 			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, |  | ||||||
| 		).Row().Scan(&count) |  | ||||||
| 		return nil |  | ||||||
| 	}) |  | ||||||
| 	return count > 0 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		var sql string |  | ||||||
| 		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) |  | ||||||
| 		if sql != "" { |  | ||||||
| 			return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error |  | ||||||
| 		} |  | ||||||
| 		return fmt.Errorf("failed to find index with name %v", oldName) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Migrator) DropIndex(value interface{}, name string) error { |  | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { |  | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { |  | ||||||
| 			name = idx.Name |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| @ -1,80 +0,0 @@ | |||||||
| package sqlite |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"database/sql" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/callbacks" |  | ||||||
| 	"gorm.io/gorm/clause" |  | ||||||
| 	"gorm.io/gorm/logger" |  | ||||||
| 	"gorm.io/gorm/migrator" |  | ||||||
| 	"gorm.io/gorm/schema" |  | ||||||
| 	_ "github.com/mattn/go-sqlite3" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Dialector struct { |  | ||||||
| 	DSN string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Open(dsn string) gorm.Dialector { |  | ||||||
| 	return &Dialector{DSN: dsn} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Name() string { |  | ||||||
| 	return "sqlite" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { |  | ||||||
| 	// register callbacks
 |  | ||||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ |  | ||||||
| 		LastInsertIDReversed: true, |  | ||||||
| 	}) |  | ||||||
| 	db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { |  | ||||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ |  | ||||||
| 		DB:                          db, |  | ||||||
| 		Dialector:                   dialector, |  | ||||||
| 		CreateIndexAfterCreateTable: true, |  | ||||||
| 	}}} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { |  | ||||||
| 	writer.WriteByte('?') |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { |  | ||||||
| 	writer.WriteByte('`') |  | ||||||
| 	writer.WriteString(str) |  | ||||||
| 	writer.WriteByte('`') |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { |  | ||||||
| 	return logger.ExplainSQL(sql, nil, `"`, vars...) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { |  | ||||||
| 	switch field.DataType { |  | ||||||
| 	case schema.Bool: |  | ||||||
| 		return "numeric" |  | ||||||
| 	case schema.Int, schema.Uint: |  | ||||||
| 		if field.AutoIncrement { |  | ||||||
| 			// https://www.sqlite.org/autoinc.html
 |  | ||||||
| 			return "integer PRIMARY KEY AUTOINCREMENT" |  | ||||||
| 		} else { |  | ||||||
| 			return "integer" |  | ||||||
| 		} |  | ||||||
| 	case schema.Float: |  | ||||||
| 		return "real" |  | ||||||
| 	case schema.String: |  | ||||||
| 		return "text" |  | ||||||
| 	case schema.Time: |  | ||||||
| 		return "datetime" |  | ||||||
| 	case schema.Bytes: |  | ||||||
| 		return "blob" |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
							
								
								
									
										6
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								go.mod
									
									
									
									
									
								
							| @ -3,12 +3,6 @@ module gorm.io/gorm | |||||||
| go 1.14 | go 1.14 | ||||||
| 
 | 
 | ||||||
| require ( | require ( | ||||||
| 	github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc |  | ||||||
| 	github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 |  | ||||||
| 	github.com/go-sql-driver/mysql v1.5.0 |  | ||||||
| 	github.com/jinzhu/inflection v1.0.0 | 	github.com/jinzhu/inflection v1.0.0 | ||||||
| 	github.com/jinzhu/now v1.1.1 | 	github.com/jinzhu/now v1.1.1 | ||||||
| 	github.com/lib/pq v1.1.1 |  | ||||||
| 	github.com/mattn/go-sqlite3 v2.0.1+incompatible |  | ||||||
| 	gorm.io/gorm v1.9.12 |  | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| 	"gorm.io/gorm/tests" | 	"gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestFieldValuerAndSetter(t *testing.T) { | func TestFieldValuerAndSetter(t *testing.T) { | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/tests" | 	"gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type User struct { | type User struct { | ||||||
|  | |||||||
| @ -7,7 +7,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| 	"gorm.io/gorm/tests" | 	"gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { | func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| 	"gorm.io/gorm/tests" | 	"gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestParseSchema(t *testing.T) { | func TestParseSchema(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestBelongsToAssociation(t *testing.T) { | func TestBelongsToAssociation(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestHasManyAssociation(t *testing.T) { | func TestHasManyAssociation(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestHasOneAssociation(t *testing.T) { | func TestHasOneAssociation(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestMany2ManyAssociation(t *testing.T) { | func TestMany2ManyAssociation(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { | func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestCount(t *testing.T) { | func TestCount(t *testing.T) { | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/now" | 	"github.com/jinzhu/now" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestCreate(t *testing.T) { | func TestCreate(t *testing.T) { | ||||||
|  | |||||||
| @ -3,8 +3,6 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 |  | ||||||
| 	. "gorm.io/gorm/tests" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestCustomizeColumn(t *testing.T) { | func TestCustomizeColumn(t *testing.T) { | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestDelete(t *testing.T) { | func TestDelete(t *testing.T) { | ||||||
|  | |||||||
| @ -4,7 +4,6 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestEmbeddedStruct(t *testing.T) { | func TestEmbeddedStruct(t *testing.T) { | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								tests/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								tests/go.mod
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | |||||||
|  | module gorm.io/gorm/tests | ||||||
|  | 
 | ||||||
|  | go 1.14 | ||||||
|  | 
 | ||||||
|  | require ( | ||||||
|  | 	github.com/jinzhu/now v1.1.1 | ||||||
|  | 	gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 | ||||||
|  | 	gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 | ||||||
|  | 	gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 | ||||||
|  | 	gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30 | ||||||
|  | 	gorm.io/gorm v1.9.12 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | replace gorm.io/gorm => ../ | ||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestGroupBy(t *testing.T) { | func TestGroupBy(t *testing.T) { | ||||||
|  | |||||||
| @ -1,17 +1,13 @@ | |||||||
| package tests | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" |  | ||||||
| 	"fmt" |  | ||||||
| 	"go/ast" |  | ||||||
| 	"reflect" |  | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/utils" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Config struct { | type Config struct { | ||||||
| @ -73,101 +69,6 @@ func GetUser(name string, config Config) *User { | |||||||
| 	return &user | 	return &user | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { |  | ||||||
| 	for _, name := range names { |  | ||||||
| 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() |  | ||||||
| 		expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() |  | ||||||
| 		t.Run(name, func(t *testing.T) { |  | ||||||
| 			AssertEqual(t, got, expect) |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func AssertEqual(t *testing.T, got, expect interface{}) { |  | ||||||
| 	if !reflect.DeepEqual(got, expect) { |  | ||||||
| 		isEqual := func() { |  | ||||||
| 			if curTime, ok := got.(time.Time); ok { |  | ||||||
| 				format := "2006-01-02T15:04:05Z07:00" |  | ||||||
| 
 |  | ||||||
| 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { |  | ||||||
| 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) |  | ||||||
| 				} |  | ||||||
| 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { |  | ||||||
| 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if fmt.Sprint(got) == fmt.Sprint(expect) { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { |  | ||||||
| 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if valuer, ok := got.(driver.Valuer); ok { |  | ||||||
| 			got, _ = valuer.Value() |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if valuer, ok := expect.(driver.Valuer); ok { |  | ||||||
| 			expect, _ = valuer.Value() |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if got != nil { |  | ||||||
| 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if expect != nil { |  | ||||||
| 			expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { |  | ||||||
| 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if reflect.ValueOf(got).Kind() == reflect.Slice { |  | ||||||
| 			if reflect.ValueOf(expect).Kind() == reflect.Slice { |  | ||||||
| 				if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { |  | ||||||
| 					for i := 0; i < reflect.ValueOf(got).Len(); i++ { |  | ||||||
| 						name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) |  | ||||||
| 						t.Run(name, func(t *testing.T) { |  | ||||||
| 							AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) |  | ||||||
| 						}) |  | ||||||
| 					} |  | ||||||
| 				} else { |  | ||||||
| 					name := reflect.ValueOf(got).Type().Elem().Name() |  | ||||||
| 					t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) |  | ||||||
| 				} |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if reflect.ValueOf(got).Kind() == reflect.Struct { |  | ||||||
| 			if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { |  | ||||||
| 				for i := 0; i < reflect.ValueOf(got).NumField(); i++ { |  | ||||||
| 					if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { |  | ||||||
| 						field := reflect.ValueOf(got).Field(i) |  | ||||||
| 						t.Run(fieldStruct.Name, func(t *testing.T) { |  | ||||||
| 							AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) |  | ||||||
| 						}) |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { |  | ||||||
| 			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() |  | ||||||
| 			isEqual() |  | ||||||
| 		} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { |  | ||||||
| 			expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() |  | ||||||
| 			isEqual() |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func CheckPet(t *testing.T, pet Pet, expect Pet) { | func CheckPet(t *testing.T, pet Pet, expect Pet) { | ||||||
| 	if pet.ID != 0 { | 	if pet.ID != 0 { | ||||||
| 		var newPet Pet | 		var newPet Pet | ||||||
| @ -6,7 +6,6 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Product struct { | type Product struct { | ||||||
|  | |||||||
| @ -5,7 +5,6 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Person struct { | type Person struct { | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestJoins(t *testing.T) { | func TestJoins(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestMain(m *testing.M) { | func TestMain(m *testing.M) { | ||||||
|  | |||||||
| @ -7,7 +7,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestMigrate(t *testing.T) { | func TestMigrate(t *testing.T) { | ||||||
|  | |||||||
| @ -4,8 +4,6 @@ import ( | |||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 |  | ||||||
| 	. "gorm.io/gorm/tests" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Blog struct { | type Blog struct { | ||||||
| @ -36,8 +34,8 @@ func compareTags(tags []Tag, contents []string) bool { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { | func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { | ||||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | ||||||
| @ -125,8 +123,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { | func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { | ||||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | ||||||
| @ -246,8 +244,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { | func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { | ||||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Hamster struct { | type Hamster struct { | ||||||
|  | |||||||
| @ -3,8 +3,6 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 |  | ||||||
| 	. "gorm.io/gorm/tests" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Animal struct { | type Animal struct { | ||||||
|  | |||||||
| @ -7,7 +7,6 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func toJSONString(v interface{}) []byte { | func toJSONString(v interface{}) []byte { | ||||||
| @ -691,8 +690,8 @@ func TestNestedPreload12(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { | func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { | ||||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	type ( | 	type ( | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestNestedPreload(t *testing.T) { | func TestNestedPreload(t *testing.T) { | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestFind(t *testing.T) { | func TestFind(t *testing.T) { | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestScan(t *testing.T) { | func TestScan(t *testing.T) { | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestScannerValuer(t *testing.T) { | func TestScannerValuer(t *testing.T) { | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func NameIn1And2(d *gorm.DB) *gorm.DB { | func NameIn1And2(d *gorm.DB) *gorm.DB { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestSoftDelete(t *testing.T) { | func TestSoftDelete(t *testing.T) { | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestRow(t *testing.T) { | func TestRow(t *testing.T) { | ||||||
|  | |||||||
| @ -18,8 +18,13 @@ for dialect in "${dialects[@]}" ; do | |||||||
|     if [ "$GORM_VERBOSE" = "" ] |     if [ "$GORM_VERBOSE" = "" ] | ||||||
|     then |     then | ||||||
|       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... |       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... | ||||||
|  |       cd tests | ||||||
|  |       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... | ||||||
|     else |     else | ||||||
|       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... |       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... | ||||||
|  |       cd tests | ||||||
|  |       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... | ||||||
|     fi |     fi | ||||||
|  |     cd .. | ||||||
|   fi |   fi | ||||||
| done | done | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| package tests | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"log" | 	"log" | ||||||
| @ -7,12 +7,13 @@ import ( | |||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"gorm.io/driver/mysql" | ||||||
|  | 	"gorm.io/driver/postgres" | ||||||
|  | 	"gorm.io/driver/sqlite" | ||||||
|  | 	"gorm.io/driver/sqlserver" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/dialects/mssql" |  | ||||||
| 	"gorm.io/gorm/dialects/mysql" |  | ||||||
| 	"gorm.io/gorm/dialects/postgres" |  | ||||||
| 	"gorm.io/gorm/dialects/sqlite" |  | ||||||
| 	"gorm.io/gorm/logger" | 	"gorm.io/gorm/logger" | ||||||
|  | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var DB *gorm.DB | var DB *gorm.DB | ||||||
| @ -40,17 +41,17 @@ func OpenTestConnection() (db *gorm.DB, err error) { | |||||||
| 			dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" | 			dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" | ||||||
| 		} | 		} | ||||||
| 		db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) | 		db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) | ||||||
| 	case "mssql": | 	case "sqlserver": | ||||||
| 		// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
 | 		// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
 | ||||||
| 		// CREATE DATABASE gorm;
 | 		// CREATE DATABASE gorm;
 | ||||||
| 		// USE gorm;
 | 		// USE gorm;
 | ||||||
| 		// CREATE USER gorm FROM LOGIN gorm;
 | 		// CREATE USER gorm FROM LOGIN gorm;
 | ||||||
| 		// sp_changedbowner 'gorm';
 | 		// sp_changedbowner 'gorm';
 | ||||||
| 		log.Println("testing mssql...") | 		log.Println("testing sqlserver...") | ||||||
| 		if dbDSN == "" { | 		if dbDSN == "" { | ||||||
| 			dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" | 			dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" | ||||||
| 		} | 		} | ||||||
| 		db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) | 		db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) | ||||||
| 	default: | 	default: | ||||||
| 		log.Println("testing sqlite3...") | 		log.Println("testing sqlite3...") | ||||||
| 		db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) | 		db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) | ||||||
| @ -90,8 +91,3 @@ func RunMigrations() { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func Now() *time.Time { |  | ||||||
| 	now := time.Now() |  | ||||||
| 	return &now |  | ||||||
| } |  | ||||||
| @ -6,7 +6,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestTransaction(t *testing.T) { | func TestTransaction(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestUpdateBelongsTo(t *testing.T) { | func TestUpdateBelongsTo(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestUpdateHasManyAssociations(t *testing.T) { | func TestUpdateHasManyAssociations(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestUpdateHasOne(t *testing.T) { | func TestUpdateHasOne(t *testing.T) { | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestUpdateMany2ManyAssociations(t *testing.T) { | func TestUpdateMany2ManyAssociations(t *testing.T) { | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestUpdate(t *testing.T) { | func TestUpdate(t *testing.T) { | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	. "gorm.io/gorm/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestUpsert(t *testing.T) { | func TestUpsert(t *testing.T) { | ||||||
|  | |||||||
							
								
								
									
										112
									
								
								utils/tests/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								utils/tests/utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,112 @@ | |||||||
|  | package tests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"fmt" | ||||||
|  | 	"go/ast" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"gorm.io/gorm/utils" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { | ||||||
|  | 	for _, name := range names { | ||||||
|  | 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() | ||||||
|  | 		expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			AssertEqual(t, got, expect) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func AssertEqual(t *testing.T, got, expect interface{}) { | ||||||
|  | 	if !reflect.DeepEqual(got, expect) { | ||||||
|  | 		isEqual := func() { | ||||||
|  | 			if curTime, ok := got.(time.Time); ok { | ||||||
|  | 				format := "2006-01-02T15:04:05Z07:00" | ||||||
|  | 
 | ||||||
|  | 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { | ||||||
|  | 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) | ||||||
|  | 				} | ||||||
|  | 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { | ||||||
|  | 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if fmt.Sprint(got) == fmt.Sprint(expect) { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { | ||||||
|  | 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if valuer, ok := got.(driver.Valuer); ok { | ||||||
|  | 			got, _ = valuer.Value() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if valuer, ok := expect.(driver.Valuer); ok { | ||||||
|  | 			expect, _ = valuer.Value() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if got != nil { | ||||||
|  | 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if expect != nil { | ||||||
|  | 			expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { | ||||||
|  | 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if reflect.ValueOf(got).Kind() == reflect.Slice { | ||||||
|  | 			if reflect.ValueOf(expect).Kind() == reflect.Slice { | ||||||
|  | 				if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { | ||||||
|  | 					for i := 0; i < reflect.ValueOf(got).Len(); i++ { | ||||||
|  | 						name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) | ||||||
|  | 						t.Run(name, func(t *testing.T) { | ||||||
|  | 							AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) | ||||||
|  | 						}) | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					name := reflect.ValueOf(got).Type().Elem().Name() | ||||||
|  | 					t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) | ||||||
|  | 				} | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if reflect.ValueOf(got).Kind() == reflect.Struct { | ||||||
|  | 			if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { | ||||||
|  | 				for i := 0; i < reflect.ValueOf(got).NumField(); i++ { | ||||||
|  | 					if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { | ||||||
|  | 						field := reflect.ValueOf(got).Field(i) | ||||||
|  | 						t.Run(fieldStruct.Name, func(t *testing.T) { | ||||||
|  | 							AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) | ||||||
|  | 						}) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { | ||||||
|  | 			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() | ||||||
|  | 			isEqual() | ||||||
|  | 		} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { | ||||||
|  | 			expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() | ||||||
|  | 			isEqual() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Now() *time.Time { | ||||||
|  | 	now := time.Now() | ||||||
|  | 	return &now | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu