Refactor tests files
This commit is contained in:
		
							parent
							
								
									5790ba9ef4
								
							
						
					
					
						commit
						8bb05a5a69
					
				| @ -7,7 +7,7 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/tests" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func BenchmarkSelect(b *testing.B) { | ||||
|  | ||||
| @ -9,7 +9,7 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/tests" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| var db, _ = gorm.Open(tests.DummyDialector{}, nil) | ||||
|  | ||||
| @ -8,7 +8,7 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/tests" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestExpr(t *testing.T) { | ||||
|  | ||||
| @ -1,225 +0,0 @@ | ||||
| package mssql | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/callbacks" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| func Create(db *gorm.DB) { | ||||
| 	if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||
| 		for _, c := range db.Statement.Schema.CreateClauses { | ||||
| 			db.Statement.AddClause(c) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if db.Statement.SQL.String() == "" { | ||||
| 		setIdentityInsert := false | ||||
| 		c := db.Statement.Clauses["ON CONFLICT"] | ||||
| 		onConflict, hasConflict := c.Expression.(clause.OnConflict) | ||||
| 
 | ||||
| 		if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { | ||||
| 			setIdentityInsert = false | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Struct: | ||||
| 				_, isZero := field.ValueOf(db.Statement.ReflectValue) | ||||
| 				setIdentityInsert = !isZero | ||||
| 			case reflect.Slice: | ||||
| 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 					_, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i)) | ||||
| 					setIdentityInsert = !isZero | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) { | ||||
| 				setIdentityInsert = true | ||||
| 				db.Statement.WriteString("SET IDENTITY_INSERT ") | ||||
| 				db.Statement.WriteQuoted(db.Statement.Table) | ||||
| 				db.Statement.WriteString(" ON;") | ||||
| 			} else { | ||||
| 				setIdentityInsert = false | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 { | ||||
| 			MergeCreate(db, onConflict) | ||||
| 		} else { | ||||
| 			db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) | ||||
| 			db.Statement.Build("INSERT") | ||||
| 			db.Statement.WriteByte(' ') | ||||
| 
 | ||||
| 			db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) | ||||
| 			if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok { | ||||
| 				if len(values.Columns) > 0 { | ||||
| 					db.Statement.WriteByte('(') | ||||
| 					for idx, column := range values.Columns { | ||||
| 						if idx > 0 { | ||||
| 							db.Statement.WriteByte(',') | ||||
| 						} | ||||
| 						db.Statement.WriteQuoted(column) | ||||
| 					} | ||||
| 					db.Statement.WriteByte(')') | ||||
| 
 | ||||
| 					outputInserted(db) | ||||
| 
 | ||||
| 					db.Statement.WriteString(" VALUES ") | ||||
| 
 | ||||
| 					for idx, value := range values.Values { | ||||
| 						if idx > 0 { | ||||
| 							db.Statement.WriteByte(',') | ||||
| 						} | ||||
| 
 | ||||
| 						db.Statement.WriteByte('(') | ||||
| 						db.Statement.AddVar(db.Statement, value...) | ||||
| 						db.Statement.WriteByte(')') | ||||
| 					} | ||||
| 
 | ||||
| 					db.Statement.WriteString(";") | ||||
| 				} else { | ||||
| 					db.Statement.WriteString("DEFAULT VALUES;") | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if setIdentityInsert { | ||||
| 			db.Statement.WriteString("SET IDENTITY_INSERT ") | ||||
| 			db.Statement.WriteQuoted(db.Statement.Table) | ||||
| 			db.Statement.WriteString(" OFF;") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !db.DryRun { | ||||
| 		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 		if err == nil { | ||||
| 			defer rows.Close() | ||||
| 
 | ||||
| 			if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { | ||||
| 				sortedKeys := []string{} | ||||
| 				for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { | ||||
| 					sortedKeys = append(sortedKeys, field.DBName) | ||||
| 				} | ||||
| 				sort.Strings(sortedKeys) | ||||
| 
 | ||||
| 				returnningFields := make([]*schema.Field, len(sortedKeys)) | ||||
| 				for idx, key := range sortedKeys { | ||||
| 					returnningFields[idx] = db.Statement.Schema.LookUpField(key) | ||||
| 				} | ||||
| 
 | ||||
| 				values := make([]interface{}, len(returnningFields)) | ||||
| 
 | ||||
| 				switch db.Statement.ReflectValue.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for rows.Next() { | ||||
| 						for idx, field := range returnningFields { | ||||
| 							values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||
| 						} | ||||
| 
 | ||||
| 						db.RowsAffected++ | ||||
| 						db.AddError(rows.Scan(values...)) | ||||
| 					} | ||||
| 				case reflect.Struct: | ||||
| 					for idx, field := range returnningFields { | ||||
| 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 					} | ||||
| 
 | ||||
| 					if rows.Next() { | ||||
| 						db.RowsAffected++ | ||||
| 						db.AddError(rows.Scan(values...)) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			db.AddError(err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { | ||||
| 	values := callbacks.ConvertToCreateValues(db.Statement) | ||||
| 
 | ||||
| 	db.Statement.WriteString("MERGE INTO ") | ||||
| 	db.Statement.WriteQuoted(db.Statement.Table) | ||||
| 	db.Statement.WriteString(" USING (VALUES") | ||||
| 	for idx, value := range values.Values { | ||||
| 		if idx > 0 { | ||||
| 			db.Statement.WriteByte(',') | ||||
| 		} | ||||
| 
 | ||||
| 		db.Statement.WriteByte('(') | ||||
| 		db.Statement.AddVar(db.Statement, value...) | ||||
| 		db.Statement.WriteByte(')') | ||||
| 	} | ||||
| 
 | ||||
| 	db.Statement.WriteString(") AS source (") | ||||
| 	for idx, column := range values.Columns { | ||||
| 		if idx > 0 { | ||||
| 			db.Statement.WriteByte(',') | ||||
| 		} | ||||
| 		db.Statement.WriteQuoted(column.Name) | ||||
| 	} | ||||
| 	db.Statement.WriteString(") ON ") | ||||
| 
 | ||||
| 	var where clause.Where | ||||
| 	for _, field := range db.Statement.Schema.PrimaryFields { | ||||
| 		where.Exprs = append(where.Exprs, clause.Eq{ | ||||
| 			Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, | ||||
| 			Value:  clause.Column{Table: "source", Name: field.DBName}, | ||||
| 		}) | ||||
| 	} | ||||
| 	where.Build(db.Statement) | ||||
| 
 | ||||
| 	if len(onConflict.DoUpdates) > 0 { | ||||
| 		db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") | ||||
| 		onConflict.DoUpdates.Build(db.Statement) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") | ||||
| 
 | ||||
| 	for idx, column := range values.Columns { | ||||
| 		if idx > 0 { | ||||
| 			db.Statement.WriteByte(',') | ||||
| 		} | ||||
| 		db.Statement.WriteQuoted(column.Name) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Statement.WriteString(") VALUES (") | ||||
| 
 | ||||
| 	for idx, column := range values.Columns { | ||||
| 		if idx > 0 { | ||||
| 			db.Statement.WriteByte(',') | ||||
| 		} | ||||
| 		db.Statement.WriteQuoted(clause.Column{ | ||||
| 			Table: "source", | ||||
| 			Name:  column.Name, | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Statement.WriteString(")") | ||||
| 	outputInserted(db) | ||||
| 	db.Statement.WriteString(";") | ||||
| } | ||||
| 
 | ||||
| func outputInserted(db *gorm.DB) { | ||||
| 	if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { | ||||
| 		sortedKeys := []string{} | ||||
| 		for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { | ||||
| 			sortedKeys = append(sortedKeys, field.DBName) | ||||
| 		} | ||||
| 		sort.Strings(sortedKeys) | ||||
| 
 | ||||
| 		db.Statement.WriteString(" OUTPUT") | ||||
| 		for idx, key := range sortedKeys { | ||||
| 			if idx > 0 { | ||||
| 				db.Statement.WriteString(",") | ||||
| 			} | ||||
| 			db.Statement.WriteString(" INSERTED.") | ||||
| 			db.Statement.AddVar(db.Statement, clause.Column{Name: key}) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -1,142 +0,0 @@ | ||||
| package mssql | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| ) | ||||
| 
 | ||||
| type Migrator struct { | ||||
| 	migrator.Migrator | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasTable(value interface{}) bool { | ||||
| 	var count int | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		return m.DB.Raw( | ||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", | ||||
| 			stmt.Table, m.CurrentDatabase(), | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) RenameTable(oldName, newName interface{}) error { | ||||
| 	var oldTable, newTable string | ||||
| 	if v, ok := oldName.(string); ok { | ||||
| 		oldTable = v | ||||
| 	} else { | ||||
| 		stmt := &gorm.Statement{DB: m.DB} | ||||
| 		if err := stmt.Parse(oldName); err == nil { | ||||
| 			oldTable = stmt.Table | ||||
| 		} else { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if v, ok := newName.(string); ok { | ||||
| 		newTable = v | ||||
| 	} else { | ||||
| 		stmt := &gorm.Statement{DB: m.DB} | ||||
| 		if err := stmt.Parse(newName); err == nil { | ||||
| 			newTable = stmt.Table | ||||
| 		} else { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return m.DB.Exec( | ||||
| 		"sp_rename @objname = ?, @newname = ?;", | ||||
| 		clause.Table{Name: oldTable}, clause.Table{Name: newTable}, | ||||
| 	).Error | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasColumn(value interface{}, field string) bool { | ||||
| 	var count int64 | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		currentDatabase := m.DB.Migrator().CurrentDatabase() | ||||
| 		name := field | ||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||
| 			name = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Raw( | ||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", | ||||
| 			currentDatabase, stmt.Table, name, | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 
 | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) AlterColumn(value interface{}, field string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||
| 			return m.DB.Exec( | ||||
| 				"ALTER TABLE ? ALTER COLUMN ? ?", | ||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), | ||||
| 			).Error | ||||
| 		} | ||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if field := stmt.Schema.LookUpField(oldName); field != nil { | ||||
| 			oldName = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		if field := stmt.Schema.LookUpField(newName); field != nil { | ||||
| 			newName = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Exec( | ||||
| 			"sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';", | ||||
| 			fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, | ||||
| 		).Error | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||
| 	var count int | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||
| 			name = idx.Name | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Raw( | ||||
| 			"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", | ||||
| 			name, stmt.Table, | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 
 | ||||
| 		return m.DB.Exec( | ||||
| 			"sp_rename @objname = ?, @newname = ?, @objtype = 'INDEX';", | ||||
| 			fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, | ||||
| 		).Error | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasConstraint(value interface{}, name string) bool { | ||||
| 	var count int64 | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		return m.DB.Raw( | ||||
| 			`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ?  AND T.Name = ? AND I.TABLE_CATALOG = ?;`, | ||||
| 			name, stmt.Table, m.CurrentDatabase(), | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) CurrentDatabase() (name string) { | ||||
| 	m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) | ||||
| 	return | ||||
| } | ||||
| @ -1,127 +0,0 @@ | ||||
| package mssql | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 
 | ||||
| 	_ "github.com/denisenkom/go-mssqldb" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/callbacks" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| 	"gorm.io/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| type Dialector struct { | ||||
| 	DSN string | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Name() string { | ||||
| 	return "mssql" | ||||
| } | ||||
| 
 | ||||
| func Open(dsn string) gorm.Dialector { | ||||
| 	return &Dialector{DSN: dsn} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) | ||||
| 	db.Callback().Create().Replace("gorm:create", Create) | ||||
| 	db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) | ||||
| 
 | ||||
| 	for k, v := range dialector.ClauseBuilders() { | ||||
| 		db.ClauseBuilders[k] = v | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { | ||||
| 	return map[string]clause.ClauseBuilder{ | ||||
| 		"LIMIT": func(c clause.Clause, builder clause.Builder) { | ||||
| 			if limit, ok := c.Expression.(clause.Limit); ok { | ||||
| 				if limit.Offset > 0 { | ||||
| 					builder.WriteString("OFFSET ") | ||||
| 					builder.WriteString(strconv.Itoa(limit.Offset)) | ||||
| 					builder.WriteString("ROWS") | ||||
| 				} | ||||
| 
 | ||||
| 				if limit.Limit > 0 { | ||||
| 					if limit.Offset == 0 { | ||||
| 						builder.WriteString(" OFFSET 0 ROWS") | ||||
| 					} | ||||
| 					builder.WriteString(" FETCH NEXT ") | ||||
| 					builder.WriteString(strconv.Itoa(limit.Limit)) | ||||
| 					builder.WriteString(" ROWS ONLY") | ||||
| 				} | ||||
| 			} | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:                          db, | ||||
| 		Dialector:                   dialector, | ||||
| 		CreateIndexAfterCreateTable: true, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { | ||||
| 	writer.WriteString("@p") | ||||
| 	writer.WriteString(strconv.Itoa(len(stmt.Vars))) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { | ||||
| 	writer.WriteByte('"') | ||||
| 	writer.WriteString(str) | ||||
| 	writer.WriteByte('"') | ||||
| } | ||||
| 
 | ||||
| var numericPlaceholder = regexp.MustCompile("@p(\\d+)") | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
| 		return "bit" | ||||
| 	case schema.Int, schema.Uint: | ||||
| 		var sqlType string | ||||
| 		switch { | ||||
| 		case field.Size < 16: | ||||
| 			sqlType = "smallint" | ||||
| 		case field.Size < 31: | ||||
| 			sqlType = "int" | ||||
| 		default: | ||||
| 			sqlType = "bigint" | ||||
| 		} | ||||
| 
 | ||||
| 		if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { | ||||
| 			return sqlType + " IDENTITY(1,1)" | ||||
| 		} | ||||
| 		return sqlType | ||||
| 	case schema.Float: | ||||
| 		return "float" | ||||
| 	case schema.String: | ||||
| 		size := field.Size | ||||
| 		if field.PrimaryKey && size == 0 { | ||||
| 			size = 256 | ||||
| 		} | ||||
| 		if size > 0 && size <= 4000 { | ||||
| 			return fmt.Sprintf("nvarchar(%d)", size) | ||||
| 		} | ||||
| 		return "nvarchar(MAX)" | ||||
| 	case schema.Time: | ||||
| 		return "datetimeoffset" | ||||
| 	case schema.Bytes: | ||||
| 		return "varbinary(MAX)" | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
| } | ||||
| @ -1,58 +0,0 @@ | ||||
| package mysql | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| ) | ||||
| 
 | ||||
| type Migrator struct { | ||||
| 	migrator.Migrator | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) AlterColumn(value interface{}, field string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||
| 			return m.DB.Exec( | ||||
| 				"ALTER TABLE ? MODIFY COLUMN ? ?", | ||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), | ||||
| 			).Error | ||||
| 		} | ||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) DropTable(values ...interface{}) error { | ||||
| 	values = m.ReorderModels(values, false) | ||||
| 	tx := m.DB.Session(&gorm.Session{}) | ||||
| 	tx.Exec("SET FOREIGN_KEY_CHECKS = 0;") | ||||
| 	for i := len(values) - 1; i >= 0; i-- { | ||||
| 		if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { | ||||
| 			return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error | ||||
| 		}); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	tx.Exec("SET FOREIGN_KEY_CHECKS = 1;") | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) DropConstraint(value interface{}, name string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		for _, chk := range stmt.Schema.ParseCheckConstraints() { | ||||
| 			if chk.Name == name { | ||||
| 				return m.DB.Exec( | ||||
| 					"ALTER TABLE ? DROP CHECK ?", | ||||
| 					clause.Table{Name: stmt.Table}, clause.Column{Name: name}, | ||||
| 				).Error | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Exec( | ||||
| 			"ALTER TABLE ? DROP FOREIGN KEY ?", | ||||
| 			clause.Table{Name: stmt.Table}, clause.Column{Name: name}, | ||||
| 		).Error | ||||
| 	}) | ||||
| } | ||||
| @ -1,169 +0,0 @@ | ||||
| package mysql | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 
 | ||||
| 	_ "github.com/go-sql-driver/mysql" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/callbacks" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| 	"gorm.io/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| type Dialector struct { | ||||
| 	DSN string | ||||
| } | ||||
| 
 | ||||
| func Open(dsn string) gorm.Dialector { | ||||
| 	return &Dialector{DSN: dsn} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Name() string { | ||||
| 	return "mysql" | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) | ||||
| 	db.ConnPool, err = sql.Open("mysql", dialector.DSN) | ||||
| 
 | ||||
| 	for k, v := range dialector.ClauseBuilders() { | ||||
| 		db.ClauseBuilders[k] = v | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { | ||||
| 	return map[string]clause.ClauseBuilder{ | ||||
| 		"ON CONFLICT": func(c clause.Clause, builder clause.Builder) { | ||||
| 			if onConflict, ok := c.Expression.(clause.OnConflict); ok { | ||||
| 				builder.WriteString("ON DUPLICATE KEY UPDATE ") | ||||
| 				if len(onConflict.DoUpdates) == 0 { | ||||
| 					if s := builder.(*gorm.Statement).Schema; s != nil { | ||||
| 						var column clause.Column | ||||
| 						onConflict.DoNothing = false | ||||
| 
 | ||||
| 						if s.PrioritizedPrimaryField != nil { | ||||
| 							column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} | ||||
| 						} else { | ||||
| 							for _, field := range s.FieldsByDBName { | ||||
| 								column = clause.Column{Name: field.DBName} | ||||
| 								break | ||||
| 							} | ||||
| 						} | ||||
| 						onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				onConflict.DoUpdates.Build(builder) | ||||
| 			} else { | ||||
| 				c.Build(builder) | ||||
| 			} | ||||
| 		}, | ||||
| 		"VALUES": func(c clause.Clause, builder clause.Builder) { | ||||
| 			if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { | ||||
| 				builder.WriteString("VALUES()") | ||||
| 				return | ||||
| 			} | ||||
| 			c.Build(builder) | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:        db, | ||||
| 		Dialector: dialector, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { | ||||
| 	writer.WriteByte('?') | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { | ||||
| 	writer.WriteByte('`') | ||||
| 	writer.WriteString(str) | ||||
| 	writer.WriteByte('`') | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, nil, `"`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
| 		return "boolean" | ||||
| 	case schema.Int, schema.Uint: | ||||
| 		sqlType := "int" | ||||
| 		switch { | ||||
| 		case field.Size <= 8: | ||||
| 			sqlType = "tinyint" | ||||
| 		case field.Size <= 16: | ||||
| 			sqlType = "smallint" | ||||
| 		case field.Size <= 32: | ||||
| 			sqlType = "int" | ||||
| 		default: | ||||
| 			sqlType = "bigint" | ||||
| 		} | ||||
| 
 | ||||
| 		if field.DataType == schema.Uint { | ||||
| 			sqlType += " unsigned" | ||||
| 		} | ||||
| 
 | ||||
| 		if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { | ||||
| 			sqlType += " AUTO_INCREMENT" | ||||
| 		} | ||||
| 		return sqlType | ||||
| 	case schema.Float: | ||||
| 		if field.Size <= 32 { | ||||
| 			return "float" | ||||
| 		} | ||||
| 		return "double" | ||||
| 	case schema.String: | ||||
| 		size := field.Size | ||||
| 		if size == 0 { | ||||
| 			if field.PrimaryKey || field.HasDefaultValue { | ||||
| 				size = 256 | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if size >= 65536 && size <= int(math.Pow(2, 24)) { | ||||
| 			return "mediumtext" | ||||
| 		} else if size > int(math.Pow(2, 24)) || size <= 0 { | ||||
| 			return "longtext" | ||||
| 		} | ||||
| 		return fmt.Sprintf("varchar(%d)", size) | ||||
| 	case schema.Time: | ||||
| 		precision := "" | ||||
| 		if field.Precision == 0 { | ||||
| 			field.Precision = 3 | ||||
| 		} | ||||
| 
 | ||||
| 		if field.Precision > 0 { | ||||
| 			precision = fmt.Sprintf("(%d)", field.Precision) | ||||
| 		} | ||||
| 
 | ||||
| 		if field.NotNull || field.PrimaryKey { | ||||
| 			return "datetime" + precision | ||||
| 		} | ||||
| 		return "datetime" + precision + " NULL" | ||||
| 	case schema.Bytes: | ||||
| 		if field.Size > 0 && field.Size < 65536 { | ||||
| 			return fmt.Sprintf("varbinary(%d)", field.Size) | ||||
| 		} | ||||
| 
 | ||||
| 		if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { | ||||
| 			return "mediumblob" | ||||
| 		} | ||||
| 
 | ||||
| 		return "longblob" | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
| } | ||||
| @ -1,139 +0,0 @@ | ||||
| package postgres | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| 	"gorm.io/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| type Migrator struct { | ||||
| 	migrator.Migrator | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) CurrentDatabase() (name string) { | ||||
| 	m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | ||||
| 	for _, opt := range opts { | ||||
| 		str := stmt.Quote(opt.DBName) | ||||
| 		if opt.Expression != "" { | ||||
| 			str = opt.Expression | ||||
| 		} | ||||
| 
 | ||||
| 		if opt.Collate != "" { | ||||
| 			str += " COLLATE " + opt.Collate | ||||
| 		} | ||||
| 
 | ||||
| 		if opt.Sort != "" { | ||||
| 			str += " " + opt.Sort | ||||
| 		} | ||||
| 		results = append(results, clause.Expr{SQL: str}) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||
| 	var count int64 | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||
| 			name = idx.Name | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Raw( | ||||
| 			"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 
 | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) CreateIndex(value interface{}, name string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||
| 			opts := m.BuildIndexOptions(idx.Fields, stmt) | ||||
| 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} | ||||
| 
 | ||||
| 			createIndexSQL := "CREATE " | ||||
| 			if idx.Class != "" { | ||||
| 				createIndexSQL += idx.Class + " " | ||||
| 			} | ||||
| 			createIndexSQL += "INDEX ?" | ||||
| 
 | ||||
| 			if idx.Type != "" { | ||||
| 				createIndexSQL += " USING " + idx.Type | ||||
| 			} | ||||
| 			createIndexSQL += " ON ??" | ||||
| 
 | ||||
| 			if idx.Where != "" { | ||||
| 				createIndexSQL += " WHERE " + idx.Where | ||||
| 			} | ||||
| 
 | ||||
| 			return m.DB.Exec(createIndexSQL, values...).Error | ||||
| 		} | ||||
| 
 | ||||
| 		return fmt.Errorf("failed to create index with name %v", name) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		return m.DB.Exec( | ||||
| 			"ALTER INDEX ? RENAME TO ?", | ||||
| 			clause.Column{Name: oldName}, clause.Column{Name: newName}, | ||||
| 		).Error | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) DropIndex(value interface{}, name string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||
| 			name = idx.Name | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasTable(value interface{}) bool { | ||||
| 	var count int64 | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema =  CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) | ||||
| 	}) | ||||
| 
 | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) DropTable(values ...interface{}) error { | ||||
| 	values = m.ReorderModels(values, false) | ||||
| 	tx := m.DB.Session(&gorm.Session{}) | ||||
| 	for i := len(values) - 1; i >= 0; i-- { | ||||
| 		if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { | ||||
| 			return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error | ||||
| 		}); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasColumn(value interface{}, field string) bool { | ||||
| 	var count int64 | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		name := field | ||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||
| 			name = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Raw( | ||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", | ||||
| 			stmt.Table, name, | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 
 | ||||
| 	return count > 0 | ||||
| } | ||||
| @ -1,102 +0,0 @@ | ||||
| package postgres | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/callbacks" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	_ "github.com/lib/pq" | ||||
| ) | ||||
| 
 | ||||
| type Dialector struct { | ||||
| 	DSN string | ||||
| } | ||||
| 
 | ||||
| func Open(dsn string) gorm.Dialector { | ||||
| 	return &Dialector{DSN: dsn} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Name() string { | ||||
| 	return "postgres" | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ | ||||
| 		WithReturning: true, | ||||
| 	}) | ||||
| 	db.ConnPool, err = sql.Open("postgres", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:                          db, | ||||
| 		Dialector:                   dialector, | ||||
| 		CreateIndexAfterCreateTable: true, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { | ||||
| 	writer.WriteByte('$') | ||||
| 	writer.WriteString(strconv.Itoa(len(stmt.Vars))) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { | ||||
| 	writer.WriteByte('"') | ||||
| 	writer.WriteString(str) | ||||
| 	writer.WriteByte('"') | ||||
| } | ||||
| 
 | ||||
| var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
| 		return "boolean" | ||||
| 	case schema.Int, schema.Uint: | ||||
| 		if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { | ||||
| 			switch { | ||||
| 			case field.Size < 16: | ||||
| 				return "smallserial" | ||||
| 			case field.Size < 31: | ||||
| 				return "serial" | ||||
| 			default: | ||||
| 				return "bigserial" | ||||
| 			} | ||||
| 		} else { | ||||
| 			switch { | ||||
| 			case field.Size < 16: | ||||
| 				return "smallint" | ||||
| 			case field.Size < 31: | ||||
| 				return "integer" | ||||
| 			default: | ||||
| 				return "bigint" | ||||
| 			} | ||||
| 		} | ||||
| 	case schema.Float: | ||||
| 		return "decimal" | ||||
| 	case schema.String: | ||||
| 		if field.Size > 0 { | ||||
| 			return fmt.Sprintf("varchar(%d)", field.Size) | ||||
| 		} | ||||
| 		return "text" | ||||
| 	case schema.Time: | ||||
| 		return "timestamptz" | ||||
| 	case schema.Bytes: | ||||
| 		return "bytea" | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
| } | ||||
| @ -1,211 +0,0 @@ | ||||
| package sqlite | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| 	"gorm.io/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| type Migrator struct { | ||||
| 	migrator.Migrator | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasTable(value interface{}) bool { | ||||
| 	var count int | ||||
| 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) | ||||
| 	}) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasColumn(value interface{}, name string) bool { | ||||
| 	var count int | ||||
| 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 			name = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Raw( | ||||
| 			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", | ||||
| 			"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) AlterColumn(value interface{}, name string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 			var ( | ||||
| 				createSQL    string | ||||
| 				newTableName = stmt.Table + "__temp" | ||||
| 			) | ||||
| 
 | ||||
| 			m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) | ||||
| 
 | ||||
| 			if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { | ||||
| 				tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) | ||||
| 				createSQL = reg.ReplaceAllString(createSQL, "?") | ||||
| 
 | ||||
| 				var columns []string | ||||
| 				columnTypes, _ := m.DB.Migrator().ColumnTypes(value) | ||||
| 				for _, columnType := range columnTypes { | ||||
| 					columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) | ||||
| 				} | ||||
| 
 | ||||
| 				createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) | ||||
| 				return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error | ||||
| 			} else { | ||||
| 				return err | ||||
| 			} | ||||
| 		} else { | ||||
| 			return fmt.Errorf("failed to alter field with name %v", name) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) DropColumn(value interface{}, name string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 			name = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		var ( | ||||
| 			createSQL    string | ||||
| 			newTableName = stmt.Table + "__temp" | ||||
| 		) | ||||
| 
 | ||||
| 		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) | ||||
| 
 | ||||
| 		if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { | ||||
| 			tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) | ||||
| 			createSQL = reg.ReplaceAllString(createSQL, "") | ||||
| 
 | ||||
| 			var columns []string | ||||
| 			columnTypes, _ := m.DB.Migrator().ColumnTypes(value) | ||||
| 			for _, columnType := range columnTypes { | ||||
| 				if columnType.Name() != name { | ||||
| 					columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) | ||||
| 
 | ||||
| 			return m.DB.Exec(createSQL).Error | ||||
| 		} else { | ||||
| 			return err | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) CreateConstraint(interface{}, string) error { | ||||
| 	return gorm.ErrNotImplemented | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) DropConstraint(interface{}, string) error { | ||||
| 	return gorm.ErrNotImplemented | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) CurrentDatabase() (name string) { | ||||
| 	var null interface{} | ||||
| 	m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | ||||
| 	for _, opt := range opts { | ||||
| 		str := stmt.Quote(opt.DBName) | ||||
| 		if opt.Expression != "" { | ||||
| 			str = opt.Expression | ||||
| 		} | ||||
| 
 | ||||
| 		if opt.Collate != "" { | ||||
| 			str += " COLLATE " + opt.Collate | ||||
| 		} | ||||
| 
 | ||||
| 		if opt.Sort != "" { | ||||
| 			str += " " + opt.Sort | ||||
| 		} | ||||
| 		results = append(results, clause.Expr{SQL: str}) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) CreateIndex(value interface{}, name string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||
| 			opts := m.BuildIndexOptions(idx.Fields, stmt) | ||||
| 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} | ||||
| 
 | ||||
| 			createIndexSQL := "CREATE " | ||||
| 			if idx.Class != "" { | ||||
| 				createIndexSQL += idx.Class + " " | ||||
| 			} | ||||
| 			createIndexSQL += "INDEX ?" | ||||
| 
 | ||||
| 			if idx.Type != "" { | ||||
| 				createIndexSQL += " USING " + idx.Type | ||||
| 			} | ||||
| 			createIndexSQL += " ON ??" | ||||
| 
 | ||||
| 			if idx.Where != "" { | ||||
| 				createIndexSQL += " WHERE " + idx.Where | ||||
| 			} | ||||
| 
 | ||||
| 			return m.DB.Exec(createIndexSQL, values...).Error | ||||
| 		} | ||||
| 
 | ||||
| 		return fmt.Errorf("failed to create index with name %v", name) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||
| 	var count int | ||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||
| 			name = idx.Name | ||||
| 		} | ||||
| 
 | ||||
| 		m.DB.Raw( | ||||
| 			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, | ||||
| 		).Row().Scan(&count) | ||||
| 		return nil | ||||
| 	}) | ||||
| 	return count > 0 | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		var sql string | ||||
| 		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) | ||||
| 		if sql != "" { | ||||
| 			return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error | ||||
| 		} | ||||
| 		return fmt.Errorf("failed to find index with name %v", oldName) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) DropIndex(value interface{}, name string) error { | ||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||
| 			name = idx.Name | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error | ||||
| 	}) | ||||
| } | ||||
| @ -1,80 +0,0 @@ | ||||
| package sqlite | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/callbacks" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| ) | ||||
| 
 | ||||
| type Dialector struct { | ||||
| 	DSN string | ||||
| } | ||||
| 
 | ||||
| func Open(dsn string) gorm.Dialector { | ||||
| 	return &Dialector{DSN: dsn} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Name() string { | ||||
| 	return "sqlite" | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ | ||||
| 		LastInsertIDReversed: true, | ||||
| 	}) | ||||
| 	db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:                          db, | ||||
| 		Dialector:                   dialector, | ||||
| 		CreateIndexAfterCreateTable: true, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { | ||||
| 	writer.WriteByte('?') | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { | ||||
| 	writer.WriteByte('`') | ||||
| 	writer.WriteString(str) | ||||
| 	writer.WriteByte('`') | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, nil, `"`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
| 		return "numeric" | ||||
| 	case schema.Int, schema.Uint: | ||||
| 		if field.AutoIncrement { | ||||
| 			// https://www.sqlite.org/autoinc.html
 | ||||
| 			return "integer PRIMARY KEY AUTOINCREMENT" | ||||
| 		} else { | ||||
| 			return "integer" | ||||
| 		} | ||||
| 	case schema.Float: | ||||
| 		return "real" | ||||
| 	case schema.String: | ||||
| 		return "text" | ||||
| 	case schema.Time: | ||||
| 		return "datetime" | ||||
| 	case schema.Bytes: | ||||
| 		return "blob" | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
| } | ||||
							
								
								
									
										6
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								go.mod
									
									
									
									
									
								
							| @ -3,12 +3,6 @@ module gorm.io/gorm | ||||
| go 1.14 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc | ||||
| 	github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 | ||||
| 	github.com/go-sql-driver/mysql v1.5.0 | ||||
| 	github.com/jinzhu/inflection v1.0.0 | ||||
| 	github.com/jinzhu/now v1.1.1 | ||||
| 	github.com/lib/pq v1.1.1 | ||||
| 	github.com/mattn/go-sqlite3 v2.0.1+incompatible | ||||
| 	gorm.io/gorm v1.9.12 | ||||
| ) | ||||
|  | ||||
| @ -9,7 +9,7 @@ import ( | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/tests" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestFieldValuerAndSetter(t *testing.T) { | ||||
|  | ||||
| @ -5,7 +5,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/tests" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| type User struct { | ||||
|  | ||||
| @ -7,7 +7,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/tests" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { | ||||
|  | ||||
| @ -5,7 +5,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/tests" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestParseSchema(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestBelongsToAssociation(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestHasManyAssociation(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestHasOneAssociation(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestMany2ManyAssociation(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { | ||||
|  | ||||
| @ -4,7 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestCount(t *testing.T) { | ||||
|  | ||||
| @ -6,7 +6,7 @@ import ( | ||||
| 
 | ||||
| 	"github.com/jinzhu/now" | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestCreate(t *testing.T) { | ||||
|  | ||||
| @ -3,8 +3,6 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestCustomizeColumn(t *testing.T) { | ||||
|  | ||||
| @ -5,7 +5,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestDelete(t *testing.T) { | ||||
|  | ||||
| @ -4,7 +4,6 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestEmbeddedStruct(t *testing.T) { | ||||
|  | ||||
							
								
								
									
										14
									
								
								tests/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								tests/go.mod
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | ||||
| module gorm.io/gorm/tests | ||||
| 
 | ||||
| go 1.14 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/jinzhu/now v1.1.1 | ||||
| 	gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 | ||||
| 	gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 | ||||
| 	gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 | ||||
| 	gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30 | ||||
| 	gorm.io/gorm v1.9.12 | ||||
| ) | ||||
| 
 | ||||
| replace gorm.io/gorm => ../ | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestGroupBy(t *testing.T) { | ||||
|  | ||||
| @ -1,17 +1,13 @@ | ||||
| package tests | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"go/ast" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/utils" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| type Config struct { | ||||
| @ -73,101 +69,6 @@ func GetUser(name string, config Config) *User { | ||||
| 	return &user | ||||
| } | ||||
| 
 | ||||
| func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { | ||||
| 	for _, name := range names { | ||||
| 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() | ||||
| 		expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			AssertEqual(t, got, expect) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func AssertEqual(t *testing.T, got, expect interface{}) { | ||||
| 	if !reflect.DeepEqual(got, expect) { | ||||
| 		isEqual := func() { | ||||
| 			if curTime, ok := got.(time.Time); ok { | ||||
| 				format := "2006-01-02T15:04:05Z07:00" | ||||
| 
 | ||||
| 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { | ||||
| 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) | ||||
| 				} | ||||
| 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { | ||||
| 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if fmt.Sprint(got) == fmt.Sprint(expect) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { | ||||
| 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := got.(driver.Valuer); ok { | ||||
| 			got, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := expect.(driver.Valuer); ok { | ||||
| 			expect, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if got != nil { | ||||
| 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||
| 		} | ||||
| 
 | ||||
| 		if expect != nil { | ||||
| 			expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { | ||||
| 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Kind() == reflect.Slice { | ||||
| 			if reflect.ValueOf(expect).Kind() == reflect.Slice { | ||||
| 				if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { | ||||
| 					for i := 0; i < reflect.ValueOf(got).Len(); i++ { | ||||
| 						name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) | ||||
| 						t.Run(name, func(t *testing.T) { | ||||
| 							AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) | ||||
| 						}) | ||||
| 					} | ||||
| 				} else { | ||||
| 					name := reflect.ValueOf(got).Type().Elem().Name() | ||||
| 					t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Kind() == reflect.Struct { | ||||
| 			if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { | ||||
| 				for i := 0; i < reflect.ValueOf(got).NumField(); i++ { | ||||
| 					if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { | ||||
| 						field := reflect.ValueOf(got).Field(i) | ||||
| 						t.Run(fieldStruct.Name, func(t *testing.T) { | ||||
| 							AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { | ||||
| 			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() | ||||
| 			isEqual() | ||||
| 		} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { | ||||
| 			expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() | ||||
| 			isEqual() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CheckPet(t *testing.T, pet Pet, expect Pet) { | ||||
| 	if pet.ID != 0 { | ||||
| 		var newPet Pet | ||||
| @ -6,7 +6,6 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| type Product struct { | ||||
|  | ||||
| @ -5,7 +5,6 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| type Person struct { | ||||
|  | ||||
| @ -5,7 +5,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestJoins(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestMain(m *testing.M) { | ||||
|  | ||||
| @ -7,7 +7,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestMigrate(t *testing.T) { | ||||
|  | ||||
| @ -4,8 +4,6 @@ import ( | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| type Blog struct { | ||||
| @ -36,8 +34,8 @@ func compareTags(tags []Tag, contents []string) bool { | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | ||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||
| 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | ||||
| @ -125,8 +123,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | ||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||
| 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | ||||
| @ -246,8 +244,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | ||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||
| 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| type Hamster struct { | ||||
|  | ||||
| @ -3,8 +3,6 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| type Animal struct { | ||||
|  | ||||
| @ -7,7 +7,6 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func toJSONString(v interface{}) []byte { | ||||
| @ -691,8 +690,8 @@ func TestNestedPreload12(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { | ||||
| 		t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") | ||||
| 	if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { | ||||
| 		t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") | ||||
| 	} | ||||
| 
 | ||||
| 	type ( | ||||
|  | ||||
| @ -6,7 +6,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestNestedPreload(t *testing.T) { | ||||
|  | ||||
| @ -9,7 +9,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestFind(t *testing.T) { | ||||
|  | ||||
| @ -6,7 +6,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestScan(t *testing.T) { | ||||
|  | ||||
| @ -11,7 +11,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestScannerValuer(t *testing.T) { | ||||
|  | ||||
| @ -4,7 +4,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func NameIn1And2(d *gorm.DB) *gorm.DB { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestSoftDelete(t *testing.T) { | ||||
|  | ||||
| @ -4,7 +4,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestRow(t *testing.T) { | ||||
|  | ||||
| @ -18,8 +18,13 @@ for dialect in "${dialects[@]}" ; do | ||||
|     if [ "$GORM_VERBOSE" = "" ] | ||||
|     then | ||||
|       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... | ||||
|       cd tests | ||||
|       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... | ||||
|     else | ||||
|       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... | ||||
|       cd tests | ||||
|       DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... | ||||
|     fi | ||||
|     cd .. | ||||
|   fi | ||||
| done | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| package tests | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"log" | ||||
| @ -7,12 +7,13 @@ import ( | ||||
| 	"path/filepath" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/driver/mysql" | ||||
| 	"gorm.io/driver/postgres" | ||||
| 	"gorm.io/driver/sqlite" | ||||
| 	"gorm.io/driver/sqlserver" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/dialects/mssql" | ||||
| 	"gorm.io/gorm/dialects/mysql" | ||||
| 	"gorm.io/gorm/dialects/postgres" | ||||
| 	"gorm.io/gorm/dialects/sqlite" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| var DB *gorm.DB | ||||
| @ -40,17 +41,17 @@ func OpenTestConnection() (db *gorm.DB, err error) { | ||||
| 			dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" | ||||
| 		} | ||||
| 		db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) | ||||
| 	case "mssql": | ||||
| 	case "sqlserver": | ||||
| 		// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
 | ||||
| 		// CREATE DATABASE gorm;
 | ||||
| 		// USE gorm;
 | ||||
| 		// CREATE USER gorm FROM LOGIN gorm;
 | ||||
| 		// sp_changedbowner 'gorm';
 | ||||
| 		log.Println("testing mssql...") | ||||
| 		log.Println("testing sqlserver...") | ||||
| 		if dbDSN == "" { | ||||
| 			dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" | ||||
| 		} | ||||
| 		db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) | ||||
| 		db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) | ||||
| 	default: | ||||
| 		log.Println("testing sqlite3...") | ||||
| 		db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) | ||||
| @ -90,8 +91,3 @@ func RunMigrations() { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Now() *time.Time { | ||||
| 	now := time.Now() | ||||
| 	return &now | ||||
| } | ||||
| @ -6,7 +6,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestTransaction(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestUpdateBelongsTo(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestUpdateHasManyAssociations(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestUpdateHasOne(t *testing.T) { | ||||
|  | ||||
| @ -3,7 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestUpdateMany2ManyAssociations(t *testing.T) { | ||||
|  | ||||
| @ -8,7 +8,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestUpdate(t *testing.T) { | ||||
|  | ||||
| @ -5,7 +5,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	. "gorm.io/gorm/tests" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestUpsert(t *testing.T) { | ||||
|  | ||||
							
								
								
									
										112
									
								
								utils/tests/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								utils/tests/utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,112 @@ | ||||
| package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"go/ast" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { | ||||
| 	for _, name := range names { | ||||
| 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() | ||||
| 		expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			AssertEqual(t, got, expect) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func AssertEqual(t *testing.T, got, expect interface{}) { | ||||
| 	if !reflect.DeepEqual(got, expect) { | ||||
| 		isEqual := func() { | ||||
| 			if curTime, ok := got.(time.Time); ok { | ||||
| 				format := "2006-01-02T15:04:05Z07:00" | ||||
| 
 | ||||
| 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { | ||||
| 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) | ||||
| 				} | ||||
| 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { | ||||
| 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if fmt.Sprint(got) == fmt.Sprint(expect) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { | ||||
| 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := got.(driver.Valuer); ok { | ||||
| 			got, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := expect.(driver.Valuer); ok { | ||||
| 			expect, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if got != nil { | ||||
| 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||
| 		} | ||||
| 
 | ||||
| 		if expect != nil { | ||||
| 			expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { | ||||
| 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Kind() == reflect.Slice { | ||||
| 			if reflect.ValueOf(expect).Kind() == reflect.Slice { | ||||
| 				if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { | ||||
| 					for i := 0; i < reflect.ValueOf(got).Len(); i++ { | ||||
| 						name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) | ||||
| 						t.Run(name, func(t *testing.T) { | ||||
| 							AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) | ||||
| 						}) | ||||
| 					} | ||||
| 				} else { | ||||
| 					name := reflect.ValueOf(got).Type().Elem().Name() | ||||
| 					t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Kind() == reflect.Struct { | ||||
| 			if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { | ||||
| 				for i := 0; i < reflect.ValueOf(got).NumField(); i++ { | ||||
| 					if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { | ||||
| 						field := reflect.ValueOf(got).Field(i) | ||||
| 						t.Run(fieldStruct.Name, func(t *testing.T) { | ||||
| 							AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { | ||||
| 			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() | ||||
| 			isEqual() | ||||
| 		} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { | ||||
| 			expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() | ||||
| 			isEqual() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Now() *time.Time { | ||||
| 	now := time.Now() | ||||
| 	return &now | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu