Merge branch 'master' into wrapper

This commit is contained in:
Jinzhu 2018-02-10 14:07:04 +08:00 committed by GitHub
commit 0a44593c9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 773 additions and 189 deletions

45
.github/ISSUE_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,45 @@
Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one.
### What version of Go are you using (`go version`)?
### Which database and its version are you using?
### Please provide a complete runnable program to reproduce your issue. **IMPORTANT**
Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config.
```go
package main
import (
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
var db *gorm.DB
func init() {
var err error
db, err = gorm.Open("sqlite3", "test.db")
// db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable")
// db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True")
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm")
if err != nil {
panic(err)
}
db.LogMode(true)
}
func main() {
if /* failure condition */ {
fmt.Println("failed")
} else {
fmt.Println("success")
}
}
```

9
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,9 @@
Make sure these boxes checked before submitting your pull request.
- [] Do only one thing
- [] No API-breaking changes
- [] New code/logic commented & tested
For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it.
### What did this pull request do?

View File

@ -1,52 +0,0 @@
# How to Contribute
## Bug Report
- Do a search on GitHub under Issues in case it has already been reported
- Submit __executable script__ or failing test pull request that could demonstrates the issue is *MUST HAVE*
## Feature Request
- Feature request with pull request is welcome
- Or it won't be implemented until I (other developers) find it is helpful for my (their) daily work
## Pull Request
- Prefer single commit pull request, that make the git history can be a bit easier to follow.
- New features need to be covered with tests to make sure your code works as expected, and won't be broken by others in future
## Contributing to Documentation
- You are welcome ;)
- You can help improve the README by making them more coherent, consistent or readable, and add more godoc documents to make people easier to follow.
- Blogs & Usage Guides & PPT also welcome, please add them to https://github.com/jinzhu/gorm/wiki/Guides
### Executable script template
```go
package main
import (
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
var db *gorm.DB
func init() {
var err error
db, err = gorm.Open("sqlite3", "test.db")
// db, err = gorm.Open("postgres", "user=username dbname=password sslmode=disable")
// db, err = gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True")
if err != nil {
panic(err)
}
db.LogMode(true)
}
func main() {
// Your code
}
```

View File

@ -3,7 +3,8 @@
The fantastic ORM library for Golang, aims to be developer friendly.
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
## Overview
@ -31,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Supporting the project
[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu)
[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu)
## Author

View File

@ -290,7 +290,9 @@ func (association *Association) Count() int {
)
}
query.Model(fieldValue).Count(&count)
if err := query.Model(fieldValue).Count(&count).Error; err != nil {
association.Error = err
}
return count
}

View File

@ -1,8 +1,6 @@
package gorm
import (
"fmt"
)
import "log"
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
@ -95,7 +93,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
fmt.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
cp.before = "gorm:row_query"
}
}
@ -109,7 +107,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.remove = true
cp.parent.processors = append(cp.parent.processors, cp)
@ -122,7 +120,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// scope.SetColumn("Updated", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.processor = &callback
cp.replace = true
@ -161,7 +159,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
for _, cp := range cps {
// show warning message the callback name already exists
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
allNames = append(allNames, cp.name)
}

View File

@ -32,8 +32,18 @@ func beforeCreateCallback(scope *Scope) {
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
scope.SetColumn("CreatedAt", now)
scope.SetColumn("UpdatedAt", now)
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
if createdAtField.IsBlank {
createdAtField.Set(now)
}
}
if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
if updatedAtField.IsBlank {
updatedAtField.Set(now)
}
}
}
}

View File

@ -15,6 +15,10 @@ func init() {
// queryCallback used to query data from database
func queryCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
defer scope.trace(NowFunc())
var (
@ -80,9 +84,7 @@ func queryCallback(scope *Scope) {
if err := rows.Err(); err != nil {
scope.Err(err)
}
if scope.db.RowsAffected == 0 && !isSlice {
} else if scope.db.RowsAffected == 0 && !isSlice {
scope.Err(ErrRecordNotFound)
}
}

View File

@ -4,11 +4,20 @@ import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
)
// preloadCallback used to preload associations
func preloadCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
if _, ok := scope.Get("gorm:auto_preload"); ok {
autoPreload(scope)
}
if scope.Search.preload == nil || scope.HasError() {
return
}
@ -79,6 +88,25 @@ func preloadCallback(scope *Scope) {
}
}
func autoPreload(scope *Scope) {
for _, field := range scope.Fields() {
if field.Relationship == nil {
continue
}
if val, ok := field.TagSettings["PRELOAD"]; ok {
if preload, err := strconv.ParseBool(val); err != nil {
scope.Err(errors.New("invalid preload option"))
return
} else if !preload {
continue
}
}
scope.Search.Preload(field.Name)
}
}
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
var (
preloadDB = scope.NewDB()
@ -264,7 +292,12 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
// generate query with join table
newScope := scope.New(reflect.New(fieldType).Interface())
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value).Select("*")
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
if len(preloadDB.search.selects) == 0 {
preloadDB = preloadDB.Select("*")
}
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
// preload inline conditions
@ -294,6 +327,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
scope.scan(rows, columns, append(fields, joinTableFields...))
scope.New(elem.Addr().Interface()).
InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries)
var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {

View File

@ -5,6 +5,8 @@ import (
"reflect"
"testing"
"time"
"github.com/jinzhu/now"
)
func TestCreate(t *testing.T) {
@ -58,6 +60,34 @@ func TestCreate(t *testing.T) {
}
}
func TestCreateWithExistingTimestamp(t *testing.T) {
user := User{Name: "CreateUserExistingTimestamp"}
timeA := now.MustParse("2016-01-01")
user.CreatedAt = timeA
user.UpdatedAt = timeA
DB.Save(&user)
if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt should not be changed")
}
if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt should not be changed")
}
var newUser User
DB.First(&newUser, user.Id)
if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt should not be changed")
}
if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt should not be changed")
}
}
type AutoIncrementUser struct {
User
Sequence uint `gorm:"AUTO_INCREMENT"`

View File

@ -33,6 +33,8 @@ type Dialect interface {
HasTable(tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
// ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
@ -41,8 +43,8 @@ type Dialect interface {
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
BuildForeignKeyName(tableName, field, dest string) string
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
BuildKeyName(kind, tableName string, fields ...string) string
// CurrentDatabase return current database name
CurrentDatabase() string
@ -114,3 +116,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}

View File

@ -31,13 +31,20 @@ func (s *commonDialect) SetDB(db SQLCommon) {
}
func (commonDialect) BindVar(i int) string {
return "$$" // ?
return "$$$" // ?
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
return value != "FALSE"
}
return field.IsPrimaryKey
}
func (s *commonDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
@ -46,13 +53,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
case reflect.Bool:
sqlType = "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if s.fieldCanAutoIncrement(field) {
sqlType = "INTEGER AUTO_INCREMENT"
} else {
sqlType = "INTEGER"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if s.fieldCanAutoIncrement(field) {
sqlType = "BIGINT AUTO_INCREMENT"
} else {
sqlType = "BIGINT"
@ -92,7 +99,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
@ -107,16 +115,23 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
@ -124,12 +139,12 @@ func (s commonDialect) CurrentDatabase() (name string) {
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
@ -144,8 +159,14 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s
return ""
}
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
return keyName
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
@ -42,29 +43,43 @@ func (s *mysql) DataTypeOf(field *StructField) string {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
case reflect.Int8:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "tinyint AUTO_INCREMENT"
} else {
sqlType = "tinyint"
}
case reflect.Int, reflect.Int16, reflect.Int32:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int AUTO_INCREMENT"
} else {
sqlType = "int"
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
case reflect.Uint8:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "tinyint unsigned AUTO_INCREMENT"
} else {
sqlType = "tinyint unsigned"
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int unsigned AUTO_INCREMENT"
} else {
sqlType = "int unsigned"
}
case reflect.Int64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint AUTO_INCREMENT"
} else {
sqlType = "bigint"
}
case reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint unsigned AUTO_INCREMENT"
} else {
@ -80,14 +95,19 @@ func (s *mysql) DataTypeOf(field *StructField) string {
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
precision := ""
if p, ok := field.TagSettings["PRECISION"]; ok {
precision = fmt.Sprintf("(%s)", p)
}
if _, ok := field.TagSettings["NOT NULL"]; ok {
sqlType = "timestamp"
sqlType = fmt.Sprintf("timestamp%v", precision)
} else {
sqlType = "timestamp NULL"
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
}
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if IsByteArrayOrSlice(dataValue) {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
@ -112,9 +132,30 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error {
return err
}
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
}
}
return
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0
}
@ -127,8 +168,8 @@ func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
if utf8.RuneCountInString(keyName) <= 64 {
return keyName
}
@ -136,8 +177,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
h.Write([]byte(keyName))
bs := h.Sum(nil)
// sha1 is 40 digits, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
// sha1 is 40 characters, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
if len(destRunes) > 24 {
destRunes = destRunes[:24]
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"encoding/json"
"fmt"
"reflect"
"strings"
@ -13,6 +14,7 @@ type postgres struct {
func init() {
RegisterDialect("postgres", &postgres{})
RegisterDialect("cloudsqlpostgres", &postgres{})
}
func (postgres) GetName() string {
@ -30,15 +32,15 @@ func (s *postgres) DataTypeOf(field *StructField) string {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "serial"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
case reflect.Int64, reflect.Uint32, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigserial"
} else {
@ -65,10 +67,16 @@ func (s *postgres) DataTypeOf(field *StructField) string {
sqlType = "hstore"
}
default:
if isByteArrayOrSlice(dataValue) {
if IsByteArrayOrSlice(dataValue) {
sqlType = "bytea"
} else if isUUID(dataValue) {
sqlType = "uuid"
if isUUID(dataValue) {
sqlType = "uuid"
}
if isJSON(dataValue) {
sqlType = "jsonb"
}
}
}
}
@ -85,7 +93,7 @@ func (s *postgres) DataTypeOf(field *StructField) string {
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}
@ -97,13 +105,13 @@ func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}
@ -120,10 +128,6 @@ func (postgres) SupportLastInsertID() bool {
return false
}
func isByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}
func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
@ -132,3 +136,8 @@ func isUUID(value reflect.Value) bool {
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}
func isJSON(value reflect.Value) bool {
_, ok := value.Interface().(json.RawMessage)
return ok
}

View File

@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "integer primary key autoincrement"
} else {
@ -54,7 +54,7 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
sqlType = "datetime"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if IsByteArrayOrSlice(dataValue) {
sqlType = "blob"
}
}

View File

@ -50,11 +50,11 @@ func (s *mssql) SetDB(db gorm.SQLCommon) {
}
func (mssql) BindVar(i int) string {
return "$$" // ?
return "$$$" // ?
}
func (mssql) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
return fmt.Sprintf(`[%s]`, key)
}
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
case reflect.Bool:
sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint IDENTITY(1,1)"
} else {
@ -81,21 +81,21 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
case reflect.Float32, reflect.Float64:
sqlType = "float"
case reflect.String:
if size > 0 && size < 65532 {
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("nvarchar(%d)", size)
} else {
sqlType = "text"
sqlType = "nvarchar(max)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime2"
sqlType = "datetimeoffset"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
if gorm.IsByteArrayOrSlice(dataValue) {
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
sqlType = "text"
sqlType = "varbinary(max)"
}
}
}
@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
return value != "FALSE"
}
return field.IsPrimaryKey
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
@ -128,16 +135,23 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
@ -145,12 +159,12 @@ func (s mssql) CurrentDatabase() (name string) {
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
}
}
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"
@ -168,3 +182,11 @@ func (mssql) SelectFromDummyTable() string {
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}

View File

@ -6,6 +6,9 @@ import (
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
"encoding/json"
"errors"
"fmt"
)
type Hstore map[string]*string
@ -52,3 +55,26 @@ func (h *Hstore) Scan(value interface{}) error {
return nil
}
// Jsonb Postgresql's JSONB data type
type Jsonb struct {
json.RawMessage
}
// Value get value of Jsonb
func (j Jsonb) Value() (driver.Value, error) {
if len(j.RawMessage) == 0 {
return nil, nil
}
return j.MarshalJSON()
}
// Scan scan value into Jsonb
func (j *Jsonb) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
return json.Unmarshal(bytes, j)
}

30
docker-compose.yml Normal file
View File

@ -0,0 +1,30 @@
version: '3'
services:
mysql:
image: 'mysql:latest'
ports:
- 9910:3306
environment:
- MYSQL_DATABASE=gorm
- MYSQL_USER=gorm
- MYSQL_PASSWORD=gorm
- MYSQL_RANDOM_ROOT_PASSWORD="yes"
postgres:
image: 'postgres:latest'
ports:
- 9920:5432
environment:
- POSTGRES_USER=gorm
- POSTGRES_DB=gorm
- POSTGRES_PASSWORD=gorm
mssql:
image: 'mcmoe/mssqldocker:latest'
ports:
- 9930:1433
environment:
- ACCEPT_EULA=Y
- SA_PASSWORD=LoremIpsum86
- MSSQL_DB=gorm
- MSSQL_USER=gorm
- MSSQL_PASSWORD=LoremIpsum86

View File

@ -29,6 +29,10 @@ func (errs Errors) GetErrors() []error {
// Add adds an error
func (errs Errors) Add(newErrors ...error) Errors {
for _, err := range newErrors {
if err == nil {
continue
}
if errors, ok := err.(Errors); ok {
errs = errs.Add(errors...)
} else {

View File

@ -59,6 +59,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
s.TableName = tableName
s.Source = JoinTableSource{ModelType: source}
s.Source.ForeignKeys = []JoinTableForeignKey{}
for idx, dbName := range relationship.ForeignFieldNames {
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: relationship.ForeignDBNames[idx],
@ -67,6 +68,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
}
s.Destination = JoinTableSource{ModelType: destination}
s.Destination.ForeignKeys = []JoinTableForeignKey{}
for idx, dbName := range relationship.AssociationForeignFieldNames {
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: relationship.AssociationForeignDBNames[idx],

View File

@ -80,3 +80,38 @@ func TestJoinTable(t *testing.T) {
t.Errorf("Should deleted all addresses")
}
}
func TestEmbeddedMany2ManyRelationship(t *testing.T) {
type EmbeddedPerson struct {
ID int
Name string
Addresses []*Address `gorm:"many2many:person_addresses;"`
}
type NewPerson struct {
EmbeddedPerson
ExternalID uint
}
DB.Exec("drop table person_addresses;")
DB.AutoMigrate(&NewPerson{})
address1 := &Address{Address1: "address 1"}
address2 := &Address{Address1: "address 2"}
person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}}
if err := DB.Save(person).Error; err != nil {
t.Errorf("no error should return when save embedded many2many relationship, but got %v", err)
}
if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil {
t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err)
}
association := DB.Model(person).Association("Addresses")
if count := association.Count(); count != 1 || association.Error != nil {
t.Errorf("Should found one address, but got %v, error is %v", count, association.Error)
}
if association.Clear(); association.Count() != 0 {
t.Errorf("Should deleted all addresses")
}
}

View File

@ -7,6 +7,7 @@ import (
"os"
"reflect"
"regexp"
"strconv"
"time"
"unicode"
)
@ -87,6 +88,7 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
}
messages = append(messages, sql)
messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned "))
} else {
messages = append(messages, "\033[31;1m")
messages = append(messages, values[2:]...)

45
main.go
View File

@ -71,11 +71,13 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
dialect: newDialect(dialect, dbSQL),
}
db.parent = db
if err == nil {
// Send a ping to make sure the database connection is alive.
if err = db.DB().Ping(); err != nil {
db.DB().Close()
if err != nil {
return
}
// Send a ping to make sure the database connection is alive.
if d, ok := dbSQL.(*sql.DB); ok {
if err = d.Ping(); err != nil {
d.Close()
}
}
return
@ -166,6 +168,15 @@ func (s *DB) NewScope(value interface{}) *Scope {
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}
// QueryExpr returns the query as expr object
func (s *DB) QueryExpr() *expr {
scope := s.NewScope(s.Value)
scope.InstanceSet("skip_bindvar", true)
scope.prepareQuerySQL()
return Expr(scope.SQL, scope.SQLVars...)
}
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db
@ -216,7 +227,7 @@ func (s *DB) Group(query string) *DB {
}
// Having specify HAVING conditions for GROUP BY
func (s *DB) Having(query string, values ...interface{}) *DB {
func (s *DB) Having(query interface{}, values ...interface{}) *DB {
return s.clone().search.Having(query, values...).db
}
@ -452,7 +463,7 @@ func (s *DB) Debug() *DB {
// Begin begin a transaction
func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok {
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon)
c.AddError(err)
@ -464,7 +475,7 @@ func (s *DB) Begin() *DB {
// Commit commit a transaction
func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.AddError(db.Commit())
} else {
s.AddError(ErrInvalidTransaction)
@ -474,7 +485,7 @@ func (s *DB) Commit() *DB {
// Rollback rollback a transaction
func (s *DB) Rollback() *DB {
if db, ok := s.db.(sqlTx); ok {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.AddError(db.Rollback())
} else {
s.AddError(ErrInvalidTransaction)
@ -600,6 +611,14 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate
return scope.db
}
// RemoveForeignKey Remove foreign key from the given scope, e.g:
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
scope := s.clone().NewScope(s.Value)
scope.removeForeignKey(field, dest)
return scope.db
}
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
func (s *DB) Association(column string) *Association {
var err error
@ -700,7 +719,7 @@ func (s *DB) GetErrors() []error {
////////////////////////////////////////////////////////////////////////////////
func (s *DB) clone() *DB {
db := DB{
db := &DB{
db: s.db,
parent: s.parent,
logger: s.logger,
@ -721,8 +740,8 @@ func (s *DB) clone() *DB {
db.search = s.search.clone()
}
db.search.db = &db
return &db
db.search.db = db
return db
}
func (s *DB) print(v ...interface{}) {
@ -737,6 +756,6 @@ func (s *DB) log(v ...interface{}) {
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
if s.logMode == 2 {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
}
}

View File

@ -36,27 +36,20 @@ func init() {
}
func OpenTestConnection() (db *gorm.DB, err error) {
dbDSN := os.Getenv("GORM_DSN")
switch os.Getenv("GORM_DIALECT") {
case "mysql":
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
// CREATE DATABASE gorm;
// GRANT ALL ON gorm.* TO 'gorm'@'localhost';
fmt.Println("testing mysql...")
dbhost := os.Getenv("GORM_DBADDRESS")
if dbhost != "" {
dbhost = fmt.Sprintf("tcp(%v)", dbhost)
if dbDSN == "" {
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
}
db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost))
db, err = gorm.Open("mysql", dbDSN)
case "postgres":
fmt.Println("testing postgres...")
dbhost := os.Getenv("GORM_DBHOST")
if dbhost != "" {
dbhost = fmt.Sprintf("host=%v ", dbhost)
if dbDSN == "" {
dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
}
db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost))
case "foundation":
fmt.Println("testing foundation...")
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
db, err = gorm.Open("postgres", dbDSN)
case "mssql":
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
// CREATE DATABASE gorm;
@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
// CREATE USER gorm FROM LOGIN gorm;
// sp_changedbowner 'gorm';
fmt.Println("testing mssql...")
db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm")
if dbDSN == "" {
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
}
db, err = gorm.Open("mssql", dbDSN)
default:
fmt.Println("testing sqlite3...")
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
@ -72,8 +68,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
// db.SetLogger(log.New(os.Stdout, "\r\n", 0))
if os.Getenv("DEBUG") == "true" {
if debug := os.Getenv("DEBUG"); debug == "true" {
db.LogMode(true)
} else if debug == "false" {
db.LogMode(false)
}
db.DB().SetMaxIdleConns(10)
@ -607,9 +605,54 @@ func TestHaving(t *testing.T) {
}
}
func TestQueryBuilderSubselectInWhere(t *testing.T) {
user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32}
DB.Save(&user)
user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16}
DB.Save(&user)
user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64}
DB.Save(&user)
user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128}
DB.Save(&user)
var users []User
DB.Select("*").Where("name IN (?)", DB.
Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
if len(users) != 4 {
t.Errorf("Four users should be found, instead found %d", len(users))
}
DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB.
Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
if len(users) != 2 {
t.Errorf("Two users should be found, instead found %d", len(users))
}
}
func TestQueryBuilderSubselectInHaving(t *testing.T) {
user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64}
DB.Save(&user)
user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128}
DB.Save(&user)
user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64}
DB.Save(&user)
user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128}
DB.Save(&user)
var users []User
DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB.
Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users)
if len(users) != 1 {
t.Errorf("Two user group should be found, instead found %d", len(users))
}
}
func DialectHasTzSupport() bool {
// NB: mssql and FoundationDB do not support time zones.
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" {
return false
}
return true

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"os"
"reflect"
"testing"
"time"
@ -432,3 +433,24 @@ func TestMultipleIndexes(t *testing.T) {
t.Error("MultipleIndexes unique index failed")
}
}
func TestModifyColumnType(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
}
type ModifyColumnType struct {
gorm.Model
Name1 string `gorm:"length:100"`
Name2 string `gorm:"length:200"`
}
DB.DropTable(&ModifyColumnType{})
DB.CreateTable(&ModifyColumnType{})
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
}
}

View File

@ -97,7 +97,11 @@ func (structField *StructField) clone() *StructField {
TagSettings: map[string]string{},
Struct: structField.Struct,
IsForeignKey: structField.IsForeignKey,
Relationship: structField.Relationship,
}
if structField.Relationship != nil {
relationship := *structField.Relationship
clone.Relationship = &relationship
}
for key, value := range structField.TagSettings {
@ -219,6 +223,15 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
subField.IsPrimaryKey = false
}
}
if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
newJoinTableHandler := &JoinTableHandler{}
newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
subField.Relationship.JoinTableHandler = newJoinTableHandler
}
}
modelStruct.StructFields = append(modelStruct.StructFields, subField)
}
continue

View File

@ -96,6 +96,33 @@ func TestPreload(t *testing.T) {
}
}
func TestAutoPreload(t *testing.T) {
user1 := getPreloadUser("auto_user1")
DB.Save(user1)
preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload")
var user User
preloadDB.Find(&user)
checkUserHasPreloadData(user, t)
user2 := getPreloadUser("auto_user2")
DB.Save(user2)
var users []User
preloadDB.Find(&users)
for _, user := range users {
checkUserHasPreloadData(user, t)
}
var users2 []*User
preloadDB.Find(&users2)
for _, user := range users2 {
checkUserHasPreloadData(*user, t)
}
}
func TestNestedPreload1(t *testing.T) {
type (
Level1 struct {
@ -1600,6 +1627,48 @@ func TestPrefixedPreloadDuplication(t *testing.T) {
}
}
func TestPreloadManyToManyCallbacks(t *testing.T) {
type (
Level2 struct {
ID uint
Name string
}
Level1 struct {
ID uint
Name string
Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"`
}
)
DB.DropTableIfExists("level1_level2s")
DB.DropTableIfExists(new(Level1))
DB.DropTableIfExists(new(Level2))
if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil {
t.Error(err)
}
lvl := Level1{
Name: "l1",
Level2s: []Level2{
Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
},
}
DB.Save(&lvl)
called := 0
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) {
called = called + 1
})
DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID)
if called != 3 {
t.Errorf("Wanted callback to be called 3 times but got %d", called)
}
}
func toJSONString(v interface{}) []byte {
r, _ := json.MarshalIndent(v, "", " ")
return r

View File

@ -222,7 +222,7 @@ func TestSearchWithStruct(t *testing.T) {
}
DB.First(&user, User{Name: user1.Name})
if user.Id == 0 || user.Name != user.Name {
if user.Id == 0 || user.Name != user1.Name {
t.Errorf("Search first record with inline struct")
}

View File

@ -253,15 +253,25 @@ func (scope *Scope) CallMethod(methodName string) {
// AddToVars add value as sql's vars, used to prevent SQL injection
func (scope *Scope) AddToVars(value interface{}) string {
_, skipBindVar := scope.InstanceGet("skip_bindvar")
if expr, ok := value.(*expr); ok {
exp := expr.expr
for _, arg := range expr.args {
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
if skipBindVar {
scope.AddToVars(arg)
} else {
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
}
}
return exp
}
scope.SQLVars = append(scope.SQLVars, value)
if skipBindVar {
return "?"
}
return scope.Dialect().BindVar(len(scope.SQLVars))
}
@ -329,18 +339,18 @@ func (scope *Scope) QuotedTableName() (name string) {
// CombinedConditionSql return combined condition sql
func (scope *Scope) CombinedConditionSql() string {
joinSql := scope.joinsSQL()
whereSql := scope.whereSQL()
joinSQL := scope.joinsSQL()
whereSQL := scope.whereSQL()
if scope.Search.raw {
whereSql = strings.TrimSuffix(strings.TrimPrefix(whereSql, "WHERE ("), ")")
whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")")
}
return joinSql + whereSql + scope.groupSQL() +
return joinSQL + whereSQL + scope.groupSQL() +
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
}
// Raw set raw sql
func (scope *Scope) Raw(sql string) *Scope {
scope.SQL = strings.Replace(sql, "$$", "?", -1)
scope.SQL = strings.Replace(sql, "$$$", "?", -1)
return scope
}
@ -448,8 +458,8 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
}
var (
columnRegexp = regexp.MustCompile("^[a-zA-Z]+(\\.[a-zA-Z]+)*$") // only match string like `name`, `users.name`
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name`
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ")
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
)
@ -1129,7 +1139,7 @@ func (scope *Scope) dropTable() *Scope {
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
}
func (scope *Scope) dropColumn(column string) {
@ -1155,7 +1165,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest)
// Compatible with old generated key
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
@ -1164,6 +1175,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
}
func (scope *Scope) removeForeignKey(field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
}
@ -1199,7 +1220,7 @@ func (scope *Scope) autoIndex() *Scope {
for _, name := range names {
if name == "INDEX" || name == "" {
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
}
indexes[name] = append(indexes[name], field.DBName)
}
@ -1210,7 +1231,7 @@ func (scope *Scope) autoIndex() *Scope {
for _, name := range names {
if name == "UNIQUE_INDEX" || name == "" {
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
}
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
}
@ -1218,11 +1239,15 @@ func (scope *Scope) autoIndex() *Scope {
}
for name, columns := range indexes {
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil {
scope.db.AddError(db.Error)
}
}
for name, columns := range uniqueIndexes {
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil {
scope.db.AddError(db.Error)
}
}
return scope

View File

@ -2,7 +2,6 @@ package gorm
import (
"fmt"
"regexp"
)
type search struct {
@ -73,13 +72,7 @@ func (s *search) Order(value interface{}, reorder ...bool) *search {
return s
}
var distinctSQLRegexp = regexp.MustCompile(`(?i)distinct[^a-z]+[a-z]+`)
func (s *search) Select(query interface{}, args ...interface{}) *search {
if distinctSQLRegexp.MatchString(fmt.Sprint(query)) {
s.ignoreOrderQuery = true
}
s.selects = map[string]interface{}{"query": query, "args": args}
return s
}
@ -104,8 +97,12 @@ func (s *search) Group(query string) *search {
return s
}
func (s *search) Having(query string, values ...interface{}) *search {
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
func (s *search) Having(query interface{}, values ...interface{}) *search {
if val, ok := query.(*expr); ok {
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
} else {
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
}
return s
}

View File

@ -1,5 +1,5 @@
dialects=("postgres" "mysql" "sqlite" "mssql")
dialects=("postgres" "mysql" "mssql" "sqlite")
for dialect in "${dialects[@]}" ; do
GORM_DIALECT=${dialect} go test
DEBUG=false GORM_DIALECT=${dialect} go test
done

View File

@ -23,7 +23,7 @@ var NowFunc = func() time.Time {
}
// Copied from golint
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
@ -97,6 +97,9 @@ func ToDBName(name string) string {
}
} else {
buf.WriteRune(v)
if i == len(value)-2 && nextCase == upper {
buf.WriteRune('_')
}
}
} else {
currCase = upper

View File

@ -9,11 +9,13 @@ import (
func TestToDBNameGenerateFriendlyName(t *testing.T) {
var maps = map[string]string{
"": "",
"X": "x",
"ThisIsATest": "this_is_a_test",
"PFAndESI": "pf_and_esi",
"AbcAndJkl": "abc_and_jkl",
"EmployeeID": "employee_id",
"SKU_ID": "sku_id",
"FieldX": "field_x",
"HTTPAndSMTP": "http_and_smtp",
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
"UUID": "uuid",

View File

@ -2,17 +2,79 @@
box: golang
services:
- id: mariadb:10.0
- name: mariadb
id: mariadb:latest
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- id: postgres
- name: mysql
id: mysql:8
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: mysql57
id: mysql:5.7
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: mysql56
id: mysql:5.6
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: mysql55
id: mysql:5.5
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: postgres
id: postgres:latest
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres96
id: postgres:9.6
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres95
id: postgres:9.5
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres94
id: postgres:9.4
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres93
id: postgres:9.3
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: mssql
id: mcmoe/mssqldocker:latest
env:
ACCEPT_EULA: Y
SA_PASSWORD: LoremIpsum86
MSSQL_DB: gorm
MSSQL_USER: gorm
MSSQL_PASSWORD: LoremIpsum86
# The steps that will be executed in the build pipeline
build:
@ -42,12 +104,57 @@ build:
code: |
go test ./...
- script:
name: test mariadb
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql
code: |
GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./...
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql5.7
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql5.6
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql5.5
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test postgres
code: |
GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./...
GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres96
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres95
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres94
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres93
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test mssql
code: |
GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...