diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 8b4f03b7..a0b64bfa 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,10 +1,4 @@ -Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`. - -DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm, - -Please answer these questions before submitting your issue. Thanks! - - +Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one. ### What version of Go are you using (`go version`)? @@ -12,9 +6,9 @@ Please answer these questions before submitting your issue. Thanks! ### Which database and its version are you using? -### What did you do? +### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** -Please provide a complete runnable program to reproduce your issue. +Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config. ```go package main @@ -32,10 +26,9 @@ var db *gorm.DB func init() { var err error db, err = gorm.Open("sqlite3", "test.db") - // Please use below username, password as your database's account for the script. - // db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - // db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True") - // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable") + // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True") + // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm") if err != nil { panic(err) } @@ -43,8 +36,6 @@ func init() { } func main() { - // your code here - if /* failure condition */ { fmt.Println("failed") } else { diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0ee0d73b..b467b6ce 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,12 +3,7 @@ Make sure these boxes checked before submitting your pull request. - [] Do only one thing - [] No API-breaking changes - [] New code/logic commented & tested -- [] Write good commit message, try to squash your commits into a single one -- [] Run `./build.sh` in `gh-pages` branch for document changes - -For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it. - -Thank you. +For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. ### What did this pull request do? diff --git a/README.md b/README.md index 44eb4a69..8c6e2302 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) +[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) +[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview @@ -31,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Supporting the project -[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) +[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu) ## Author diff --git a/callback_query.go b/callback_query.go index 20e88161..ba10cc7d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,6 +15,10 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } + defer scope.trace(NowFunc()) var ( diff --git a/callback_query_preload.go b/callback_query_preload.go index 21ab22ce..30f6b585 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -10,6 +10,9 @@ import ( // preloadCallback used to preload associations func preloadCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } if _, ok := scope.Get("gorm:auto_preload"); ok { autoPreload(scope) @@ -324,6 +327,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) + scope.New(elem.Addr().Interface()). + InstanceSet("gorm:skip_query_callback", true). + callCallbacks(scope.db.parent.callbacks.queries) + var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table for idx, joinTableField := range joinTableFields { diff --git a/dialect.go b/dialect.go index e879588b..fe8e2f62 100644 --- a/dialect.go +++ b/dialect.go @@ -33,6 +33,8 @@ type Dialect interface { HasTable(tableName string) bool // HasColumn check has column or not HasColumn(tableName string, columnName string) bool + // ModifyColumn modify column's type + ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) string @@ -41,8 +43,8 @@ type Dialect interface { // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string - // BuildForeignKeyName returns a foreign key name for the given table, field and reference - BuildForeignKeyName(tableName, field, dest string) string + // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference + BuildKeyName(kind, tableName string, fields ...string) string // CurrentDatabase return current database name CurrentDatabase() string @@ -114,3 +116,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel return fieldValue, dataType, size, strings.TrimSpace(additionalType) } + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} diff --git a/dialect_common.go b/dialect_common.go index a99627f2..1e5e3b61 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -38,6 +38,13 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + func (s *commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) @@ -46,13 +53,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "INTEGER AUTO_INCREMENT" } else { sqlType = "INTEGER" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "BIGINT AUTO_INCREMENT" } else { sqlType = "BIGINT" @@ -92,7 +99,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -107,16 +115,23 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } +func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) + return err +} + func (s commonDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return @@ -144,9 +159,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } -func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { - keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") +// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference +func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) + keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 6fcd0079..1feed1f6 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -44,42 +44,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint unsigned AUTO_INCREMENT" } else { @@ -95,10 +95,15 @@ func (s *mysql) DataTypeOf(field *StructField) string { } case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { + precision := "" + if p, ok := field.TagSettings["PRECISION"]; ok { + precision = fmt.Sprintf("(%s)", p) + } + if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" + sqlType = fmt.Sprintf("timestamp%v", precision) } else { - sqlType = "timestamp NULL" + sqlType = fmt.Sprintf("timestamp%v NULL", precision) } } default: @@ -127,6 +132,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error { return err } +func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { @@ -144,7 +154,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } @@ -157,8 +168,8 @@ func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } -func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { - keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) +func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) if utf8.RuneCountInString(keyName) <= 64 { return keyName } @@ -166,8 +177,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { h.Write([]byte(keyName)) bs := h.Sum(nil) - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + // sha1 is 40 characters, keep first 24 characters of destination + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df1..c44c6a5b 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "encoding/json" "fmt" "reflect" "strings" @@ -32,14 +33,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigserial" } else { @@ -68,9 +69,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { default: if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" + if isUUID(dataValue) { sqlType = "uuid" } + + if isJSON(dataValue) { + sqlType = "jsonb" + } } } } @@ -130,3 +136,8 @@ func isUUID(value reflect.Value) bool { lower := strings.ToLower(typename) return "uuid" == lower || "guid" == lower } + +func isJSON(value reflect.Value) bool { + _, ok := value.Interface().(json.RawMessage) + return ok +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb..f26f6be3 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..1dd5fb69 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -54,7 +54,7 @@ func (mssql) BindVar(i int) string { } func (mssql) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) + return fmt.Sprintf(`[%s]`, key) } func (s *mssql) DataTypeOf(field *gorm.StructField) string { @@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { case reflect.Bool: sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint IDENTITY(1,1)" } else { @@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } +func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + func (s mssql) HasIndex(tableName string, indexName string) bool { var count int s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) @@ -128,16 +135,23 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) return count > 0 } func (s mssql) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } +func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mssql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) return @@ -168,3 +182,11 @@ func (mssql) SelectFromDummyTable() string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index b8e76891..1d0dcb60 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -63,6 +63,9 @@ type Jsonb struct { // Value get value of Jsonb func (j Jsonb) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } return j.MarshalJSON() } diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..79bf5fc3 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - POSTGRES_USER=gorm + - POSTGRES_DB=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/main.go b/main.go index 16fa0b79..b23ae2f2 100644 --- a/main.go +++ b/main.go @@ -611,6 +611,14 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate return scope.db } +// RemoveForeignKey Remove foreign key from the given scope, e.g: +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +func (s *DB) RemoveForeignKey(field string, dest string) *DB { + scope := s.clone().NewScope(s.Value) + scope.removeForeignKey(field, dest) + return scope.db +} + // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error diff --git a/main_test.go b/main_test.go index 34f96a86..83e6f7aa 100644 --- a/main_test.go +++ b/main_test.go @@ -36,27 +36,20 @@ func init() { } func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; fmt.Println("testing mysql...") - dbhost := os.Getenv("GORM_DBADDRESS") - if dbhost != "" { - dbhost = fmt.Sprintf("tcp(%v)", dbhost) + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" } - db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost)) + db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") - dbhost := os.Getenv("GORM_DBHOST") - if dbhost != "" { - dbhost = fmt.Sprintf("host=%v ", dbhost) + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" } - db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost)) - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") + db, err = gorm.Open("postgres", dbDSN) case "mssql": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; @@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open("mssql", dbDSN) default: fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) @@ -72,8 +68,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { + if debug := os.Getenv("DEBUG"); debug == "true" { db.LogMode(true) + } else if debug == "false" { + db.LogMode(false) } db.DB().SetMaxIdleConns(10) diff --git a/migration_test.go b/migration_test.go index 3f3a5c8f..6b4470a6 100644 --- a/migration_test.go +++ b/migration_test.go @@ -435,10 +435,7 @@ func TestMultipleIndexes(t *testing.T) { } func TestModifyColumnType(t *testing.T) { - dialect := os.Getenv("GORM_DIALECT") - if dialect != "postgres" && - dialect != "mysql" && - dialect != "mssql" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") } diff --git a/preload_test.go b/preload_test.go index 1b89e77b..311ad0be 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1627,6 +1627,48 @@ func TestPrefixedPreloadDuplication(t *testing.T) { } } +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + Name string + } + Level1 struct { + ID uint + Name string + Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` + } + ) + + DB.DropTableIfExists("level1_level2s") + DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level2)) + + if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { + t.Error(err) + } + + lvl := Level1{ + Name: "l1", + Level2s: []Level2{ + Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, + }, + } + DB.Save(&lvl) + + called := 0 + + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { + called = called + 1 + }) + + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) + + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r diff --git a/scope.go b/scope.go index 0f2dbda1..4b3e99f8 100644 --- a/scope.go +++ b/scope.go @@ -1158,7 +1158,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() + scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) } func (scope *Scope) dropColumn(column string) { @@ -1184,7 +1184,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + // Compatible with old generated key + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return @@ -1193,6 +1194,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } +func (scope *Scope) removeForeignKey(field string, dest string) { + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) + + if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() +} + func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } @@ -1228,7 +1239,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "INDEX" || name == "" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) } indexes[name] = append(indexes[name], field.DBName) } @@ -1239,7 +1250,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "UNIQUE_INDEX" || name == "" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) } uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) } @@ -1247,11 +1258,15 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) + } } for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) + } } return scope diff --git a/test_all.sh b/test_all.sh index 80b319bf..5cfb3321 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,5 +1,5 @@ dialects=("postgres" "mysql" "mssql" "sqlite") for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test + DEBUG=false GORM_DIALECT=${dialect} go test done diff --git a/utils.go b/utils.go index 97a3d175..dfaae939 100644 --- a/utils.go +++ b/utils.go @@ -23,7 +23,7 @@ var NowFunc = func() time.Time { } // Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} +var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) diff --git a/wercker.yml b/wercker.yml index ff6fb17c..2f2370b3 100644 --- a/wercker.yml +++ b/wercker.yml @@ -2,17 +2,79 @@ box: golang services: - - id: mariadb:10.0 + - name: mariadb + id: mariadb:latest env: MYSQL_DATABASE: gorm MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - id: postgres + - name: mysql + id: mysql:8 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql57 + id: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql56 + id: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql55 + id: mysql:5.5 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: postgres + id: postgres:latest env: POSTGRES_USER: gorm POSTGRES_PASSWORD: gorm POSTGRES_DB: gorm + - name: postgres96 + id: postgres:9.6 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres95 + id: postgres:9.5 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres94 + id: postgres:9.4 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres93 + id: postgres:9.3 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: mssql + id: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 # The steps that will be executed in the build pipeline build: @@ -42,12 +104,57 @@ build: code: | go test ./... + - script: + name: test mariadb + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... + - script: name: test mysql code: | - GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.7 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.6 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.5 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... - script: name: test postgres code: | - GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres96 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres95 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres94 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres93 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test mssql + code: | + GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...