From e77fbed442c252dfadd77b1534260389597bdd4a Mon Sep 17 00:00:00 2001 From: deoxxa Date: Mon, 16 Mar 2015 11:22:31 +1100 Subject: [PATCH 01/17] scanner.Scan() can fail, so the error should be forwarded --- field.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/field.go b/field.go index e122adb4..8f5efa6d 100644 --- a/field.go +++ b/field.go @@ -27,9 +27,13 @@ func (field *Field) Set(value interface{}) error { if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { if v, ok := value.(reflect.Value); ok { - scanner.Scan(v.Interface()) + if err := scanner.Scan(v.Interface()); err != nil { + return err + } } else { - scanner.Scan(value) + if err := scanner.Scan(value); err != nil { + return err + } } } else { reflectValue, ok := value.(reflect.Value) From a0848909c256ecc2d877090f933a8b6120cf2fb0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Mar 2015 10:40:42 +0800 Subject: [PATCH 02/17] Simplify dialect definitions --- common_dialect.go | 32 +++++++++++++------------- mssql.go | 38 +++++++------------------------ mysql.go | 58 +++++------------------------------------------ postgres.go | 29 ++++++++---------------- sqlite3.go | 36 ++++++----------------------- 5 files changed, 46 insertions(+), 147 deletions(-) diff --git a/common_dialect.go b/common_dialect.go index 9360cd26..281df8a7 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -9,19 +9,19 @@ import ( type commonDialect struct{} -func (s *commonDialect) BinVar(i int) string { - return "?" +func (commonDialect) BinVar(i int) string { + return "$$" // ? } -func (s *commonDialect) SupportLastInsertId() bool { +func (commonDialect) SupportLastInsertId() bool { return true } -func (s *commonDialect) HasTop() bool { +func (commonDialect) HasTop() bool { return false } -func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "BOOLEAN" @@ -57,19 +57,19 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) } -func (s *commonDialect) ReturningStr(tableName, key string) string { +func (commonDialect) ReturningStr(tableName, key string) string { return "" } -func (s *commonDialect) SelectFromDummyTable() string { +func (commonDialect) SelectFromDummyTable() string { return "" } -func (s *commonDialect) Quote(key string) string { - return fmt.Sprintf("`%s`", key) +func (commonDialect) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) } -func (s *commonDialect) databaseName(scope *Scope) string { +func (commonDialect) databaseName(scope *Scope) string { from := strings.Index(scope.db.parent.source, "/") + 1 to := strings.Index(scope.db.parent.source, "?") if to == -1 { @@ -78,24 +78,24 @@ func (s *commonDialect) databaseName(scope *Scope) string { return scope.db.parent.source[from:to] } -func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { +func (c commonDialect) HasTable(scope *Scope, tableName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count) return count > 0 } -func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) return count > 0 } -func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) { +func (commonDialect) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/mssql.go b/mssql.go index dc8e2917..c44541c7 100644 --- a/mssql.go +++ b/mssql.go @@ -7,21 +7,15 @@ import ( "time" ) -type mssql struct{} - -func (s *mssql) BinVar(i int) string { - return "$$" // ? +type mssql struct { + commonDialect } -func (s *mssql) SupportLastInsertId() bool { +func (mssql) HasTop() bool { return true } -func (s *mssql) HasTop() bool { - return true -} - -func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bit" @@ -57,19 +51,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) } -func (s *mssql) ReturningStr(tableName, key string) string { - return "" -} - -func (s *mssql) SelectFromDummyTable() string { - return "" -} - -func (s *mssql) Quote(key string) string { - return fmt.Sprintf(" \"%s\"", key) -} - -func (s *mssql) databaseName(scope *Scope) string { +func (mssql) databaseName(scope *Scope) string { dbStr := strings.Split(scope.db.parent.source, ";") for _, value := range dbStr { s := strings.Split(value, "=") @@ -80,24 +62,20 @@ func (s *mssql) databaseName(scope *Scope) string { return "" } -func (s *mssql) HasTable(scope *Scope, tableName string) bool { +func (s mssql) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count) return count > 0 } -func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) return count > 0 } - -func (s *mssql) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) -} diff --git a/mysql.go b/mysql.go index d2eb08a5..e37a23e0 100644 --- a/mysql.go +++ b/mysql.go @@ -3,25 +3,14 @@ package gorm import ( "fmt" "reflect" - "strings" "time" ) -type mysql struct{} - -func (s *mysql) BinVar(i int) string { - return "$$" // ? +type mysql struct { + commonDialect } -func (s *mysql) SupportLastInsertId() bool { - return true -} - -func (s *mysql) HasTop() bool { - return false -} - -func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -57,45 +46,10 @@ func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s *mysql) ReturningStr(tableName, key string) string { - return "" -} - -func (s *mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s *mysql) Quote(key string) string { +func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } -func (s *mysql) databaseName(scope *Scope) string { - from := strings.Index(scope.db.parent.source, "/") + 1 - to := strings.Index(scope.db.parent.source, "?") - if to == -1 { - to = len(scope.db.parent.source) - } - return scope.db.parent.source[from:to] -} - -func (s *mysql) HasTable(scope *Scope, tableName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) +func (mysql) SelectFromDummyTable() string { + return "FROM DUAL" } diff --git a/postgres.go b/postgres.go index 83c37e1f..4218e1ba 100644 --- a/postgres.go +++ b/postgres.go @@ -11,21 +11,18 @@ import ( ) type postgres struct { + commonDialect } -func (s *postgres) BinVar(i int) string { +func (postgres) BinVar(i int) string { return fmt.Sprintf("$%v", i) } -func (s *postgres) SupportLastInsertId() bool { +func (postgres) SupportLastInsertId() bool { return false } -func (s *postgres) HasTop() bool { - return false -} - -func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -62,35 +59,27 @@ func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) stri panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -func (s *postgres) ReturningStr(tableName, key string) string { +func (s postgres) ReturningStr(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) } -func (s *postgres) SelectFromDummyTable() string { - return "" -} - -func (s *postgres) Quote(key string) string { - return fmt.Sprintf("\"%s\"", key) -} - -func (s *postgres) HasTable(scope *Scope, tableName string) bool { +func (postgres) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count) return count > 0 } -func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *postgres) RemoveIndex(scope *Scope, indexName string) { +func (postgres) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } -func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count) return count > 0 diff --git a/sqlite3.go b/sqlite3.go index ce71ee08..afe70e3a 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -6,21 +6,11 @@ import ( "time" ) -type sqlite3 struct{} - -func (s *sqlite3) BinVar(i int) string { - return "$$" // ? +type sqlite3 struct { + commonDialect } -func (s *sqlite3) SupportLastInsertId() bool { - return true -} - -func (s *sqlite3) HasTop() bool { - return false -} - -func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bool" @@ -50,36 +40,24 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) strin panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (s *sqlite3) ReturningStr(tableName, key string) string { - return "" -} - -func (s *sqlite3) SelectFromDummyTable() string { - return "" -} - -func (s *sqlite3) Quote(key string) string { - return fmt.Sprintf("\"%s\"", key) -} - -func (s *sqlite3) HasTable(scope *Scope, tableName string) bool { +func (sqlite3) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) { +func (sqlite3) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } From 87ff58b598c607fa260295732bc00e67b3272894 Mon Sep 17 00:00:00 2001 From: Will Glynn Date: Tue, 17 Mar 2015 09:04:12 -0500 Subject: [PATCH 03/17] Fix creation of composite unique indexes --- scope_private.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope_private.go b/scope_private.go index e51e1faf..6d700cb9 100644 --- a/scope_private.go +++ b/scope_private.go @@ -570,7 +570,7 @@ func (scope *Scope) autoIndex() *Scope { if name == "UNIQUE_INDEX" { name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) } - uniqueIndexes[name] = append(indexes[name], field.DBName) + uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) } } From c13e2f18f8d6f571e4ab4229dbc782cccc2f4125 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 18 Mar 2015 11:47:11 +0800 Subject: [PATCH 04/17] New JoinTableHandler --- association.go | 8 ++--- callback_shared.go | 7 ++-- join_table.go | 86 ++++++++++++++++++++++++++++++++-------------- main.go | 27 --------------- model_struct.go | 4 +-- scope_private.go | 22 ++++-------- 6 files changed, 76 insertions(+), 78 deletions(-) diff --git a/association.go b/association.go index b011971a..60763f8c 100644 --- a/association.go +++ b/association.go @@ -77,7 +77,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) @@ -132,7 +132,7 @@ func (association *Association) Replace(values ...interface{}) *Association { sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) - association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) } else { association.setErr(errors.New("replace only support many to many")) } @@ -145,7 +145,7 @@ func (association *Association) Clear() *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey) - if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -165,7 +165,7 @@ func (association *Association) Count() int { if relationship.Kind == "many_to_many" { query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) - scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count) + relationship.JoinTableHandler.JoinWith(query, association.Scope.Value).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/callback_shared.go b/callback_shared.go index 99ad8f50..ae75c250 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -38,7 +38,7 @@ func SaveAfterAssociations(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTable == "" && relationship.ForeignFieldName != "" { + if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" { scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) } @@ -48,9 +48,8 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newDB.Save(elem).Error) - if joinTable := relationship.JoinTable; joinTable != "" { - scope.Err(scope.db.GetJoinTableHandler(joinTable). - Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue())) + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value)) } } default: diff --git a/join_table.go b/join_table.go index 3ffa4f87..b6b33d9a 100644 --- a/join_table.go +++ b/join_table.go @@ -5,44 +5,78 @@ import ( "strings" ) -type JoinTableHandler interface { - Table(*DB, *Relationship) string - Add(*DB, *Relationship, interface{}, interface{}) error - Delete(*DB, *Relationship) error - Scope(*DB, *Relationship) *DB +type JoinTableHandlerInterface interface { + Table(db *DB) string + Add(db *DB, source1 interface{}, source2 interface{}) error + Delete(db *DB, sources ...interface{}) error + JoinWith(db *DB, source interface{}) *DB } -type defaultJoinTableHandler struct{} - -func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string { - return relationship.JoinTable +type JoinTableSource struct { + ForeignKey string + ForeignKeyPrefix string + ModelStruct } -func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error { +type JoinTableHandler struct { + TableName string + Source1 JoinTableSource + Source2 JoinTableSource +} + +func (jt JoinTableHandler) Table(*DB) string { + return jt.TableName +} + +func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} { + values := map[string]interface{}{} + for _, source := range sources { + scope := db.NewScope(source) + for _, primaryField := range scope.GetModelStruct().PrimaryFields { + if field, ok := scope.Fields()[primaryField.DBName]; ok { + values[primaryField.DBName] = field.Field.Interface() + } + } + } + return values +} + +func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") - quotedForeignDBName := scope.Quote(relationship.ForeignDBName) - quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName) - table := s.Table(db, relationship) + valueMap := jt.GetValueMap(db, source1, source2) + var setColumns, setBinVars, queryConditions []string + var values []interface{} + for key, value := range valueMap { + setColumns = append(setColumns, key) + setBinVars = append(setBinVars, `?`) + queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key))) + values = append(values, value) + } + + for _, value := range valueMap { + values = append(values, value) + } + + quotedTable := jt.Table(db) sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);", - scope.Quote(table), - strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","), + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);", + quotedTable, + strings.Join(setColumns, ","), + strings.Join(setBinVars, ","), scope.Dialect().SelectFromDummyTable(), - scope.Quote(table), - quotedForeignDBName, - quotedAssociationDBName, + quotedTable, + strings.Join(queryConditions, " AND "), ) - return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error + return db.Exec(sql, values...).Error } -func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error { - return db.Table(s.Table(db, relationship)).Delete("").Error +func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error { + // return db.Table(jt.Table(db)).Delete("").Error + return nil } -func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB { - return db.Table(s.Table(db, relationship)) +func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB { + return db } - -var DefaultJoinTableHandler = &defaultJoinTableHandler{} diff --git a/main.go b/main.go index 87fecd59..377ef582 100644 --- a/main.go +++ b/main.go @@ -469,30 +469,3 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { value, ok = s.values[name] return } - -func (s *DB) GetJoinTableHandler(table string) JoinTableHandler { - if s.parent.joinTableHandlers != nil { - if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok { - return joinTableHandler - } - if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok { - return joinTableHandler - } - } - return DefaultJoinTableHandler -} - -func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) { - if s.parent.joinTableHandlers == nil { - s.parent.joinTableHandlers = map[string]JoinTableHandler{} - } - - if len(tables) > 0 { - for _, table := range tables { - s.parent.joinTableHandlers[table] = joinTableHandler - s.Table(table).AutoMigrate(joinTableHandler) - } - } else { - s.parent.joinTableHandlers["*"] = joinTableHandler - } -} diff --git a/model_struct.go b/model_struct.go index 02e23dfe..cce28330 100644 --- a/model_struct.go +++ b/model_struct.go @@ -60,7 +60,7 @@ type Relationship struct { ForeignDBName string AssociationForeignFieldName string AssociationForeignDBName string - JoinTable string + JoinTableHandler JoinTableHandlerInterface } var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} @@ -205,7 +205,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := gormSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - relationship.JoinTable = many2many + relationship.JoinTableHandler = JoinTableHandler{} associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] if associationForeignKey == "" { diff --git a/scope_private.go b/scope_private.go index 6d700cb9..5755a60c 100644 --- a/scope_private.go +++ b/scope_private.go @@ -402,18 +402,11 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if fromField != nil { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { - joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) - quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship)) - - joinSql := fmt.Sprintf( - "INNER JOIN %v ON %v.%v = %v.%v", - quotedJoinTable, - quotedJoinTable, - scope.Quote(relationship.AssociationForeignDBName), - toScope.QuotedTableName(), - scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)) - scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error) + joinTableHandler := relationship.JoinTableHandler + quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db)) + scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value). + Where(fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)), scope.PrimaryKeyValue()). + Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() @@ -443,9 +436,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" { - joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) - joinTable := joinTableHandler.Table(scope.db, relationship) + if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { + joinTable := relationship.JoinTableHandler.Table(scope.db) if !scope.Dialect().HasTable(scope, joinTable) { primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", From 6ba0c1661f356dea7e2c97aa0bda83362275a797 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 18 Mar 2015 18:14:28 +0800 Subject: [PATCH 05/17] Refactor JoinTableHandler --- join_table.go | 103 ++++++++++++++++++++++++++++++++------------- join_table_test.go | 9 ++-- scope_private.go | 5 ++- 3 files changed, 80 insertions(+), 37 deletions(-) diff --git a/join_table.go b/join_table.go index b6b33d9a..2aeb1c4a 100644 --- a/join_table.go +++ b/join_table.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "reflect" "strings" ) @@ -13,70 +14,114 @@ type JoinTableHandlerInterface interface { } type JoinTableSource struct { - ForeignKey string - ForeignKeyPrefix string - ModelStruct + ModelType reflect.Type + ForeignKeys []struct { + DBName string + AssociationDBName string + } } type JoinTableHandler struct { - TableName string - Source1 JoinTableSource - Source2 JoinTableSource + TableName string `sql:"-"` + Source JoinTableSource `sql:"-"` + Destination JoinTableSource `sql:"-"` } -func (jt JoinTableHandler) Table(*DB) string { - return jt.TableName +func (s JoinTableHandler) Table(*DB) string { + return s.TableName } -func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} { +func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { values := map[string]interface{}{} + for _, source := range sources { scope := db.NewScope(source) - for _, primaryField := range scope.GetModelStruct().PrimaryFields { - if field, ok := scope.Fields()[primaryField.DBName]; ok { - values[primaryField.DBName] = field.Field.Interface() + modelType := scope.GetModelStruct().ModelType + + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Source.ForeignKeys { + values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + } + } else if s.Destination.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() } } } return values } -func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { +func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") - valueMap := jt.GetValueMap(db, source1, source2) + searchMap := s.GetSearchMap(db, source1, source2) - var setColumns, setBinVars, queryConditions []string + var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range valueMap { - setColumns = append(setColumns, key) - setBinVars = append(setBinVars, `?`) - queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key))) + for key, value := range searchMap { + assignColumns = append(assignColumns, key) + binVars = append(binVars, `?`) + conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } - for _, value := range valueMap { + for _, value := range searchMap { values = append(values, value) } - quotedTable := jt.Table(db) + quotedTable := s.Table(db) sql := fmt.Sprintf( "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);", quotedTable, - strings.Join(setColumns, ","), - strings.Join(setBinVars, ","), + strings.Join(assignColumns, ","), + strings.Join(binVars, ","), scope.Dialect().SelectFromDummyTable(), quotedTable, - strings.Join(queryConditions, " AND "), + strings.Join(conditions, " AND "), ) return db.Exec(sql, values...).Error } -func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error { - // return db.Table(jt.Table(db)).Delete("").Error - return nil +func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { + var conditions []string + var values []interface{} + + for key, value := range s.GetSearchMap(db, sources...) { + conditions = append(conditions, fmt.Sprintf("%v = ?", key)) + values = append(values, value) + } + + return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } -func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB { - return db +func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { + quotedTable := s.Table(db) + + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + var joinConditions []string + var queryConditions []string + var values []interface{} + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Source.ForeignKeys { + queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) + values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) + } + } else if s.Destination.ModelType == modelType { + for _, foreignKey := range s.Source.ForeignKeys { + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Destination.ForeignKeys { + queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) + values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) + } + } + + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) } diff --git a/join_table_test.go b/join_table_test.go index 2624fdb2..429e46e1 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -15,16 +15,13 @@ type Person struct { } type PersonAddress struct { + gorm.JoinTableHandler PersonID int AddressID int DeletedAt time.Time CreatedAt time.Time } -func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string { - return relationship.JoinTable -} - func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ relationship.ForeignDBName: foreignValue, @@ -41,14 +38,14 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error } func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { - table := pa.Table(db, relationship) + table := pa.Table(db) return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { DB.Exec("drop table person_addresses;") DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") + // DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") address1 := &Address{Address1: "address 1"} address2 := &Address{Address1: "address 2"} diff --git a/scope_private.go b/scope_private.go index 5755a60c..da78d3f2 100644 --- a/scope_private.go +++ b/scope_private.go @@ -437,7 +437,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTable := relationship.JoinTableHandler.Table(scope.db) + joinTableHandler := relationship.JoinTableHandler + joinTable := joinTableHandler.Table(scope.db) if !scope.Dialect().HasTable(scope, joinTable) { primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", @@ -447,7 +448,7 @@ func (scope *Scope) createJoinTable(field *StructField) { scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), ).Error) } - scope.NewDB().Table(joinTable).AutoMigrate() + scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } } From fa753969b1e5b8b2a8519135ef7eb8d21783cb41 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 19 Mar 2015 15:02:15 +0800 Subject: [PATCH 06/17] Fix stack overflow --- join_table.go | 32 +++++- model_struct.go | 261 +++++++++++++++++++++++++----------------------- 2 files changed, 162 insertions(+), 131 deletions(-) diff --git a/join_table.go b/join_table.go index 2aeb1c4a..9b23f89f 100644 --- a/join_table.go +++ b/join_table.go @@ -13,12 +13,36 @@ type JoinTableHandlerInterface interface { JoinWith(db *DB, source interface{}) *DB } +type JoinTableForeignKey struct { + DBName string + AssociationDBName string +} + +func updateJoinTableHandler(relationship *Relationship) { + handler := relationship.JoinTableHandler.(*JoinTableHandler) + + destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} + for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { + db := relationship.ForeignDBName + handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } + + sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} + for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + db := relationship.AssociationForeignDBName + handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } +} + type JoinTableSource struct { ModelType reflect.Type - ForeignKeys []struct { - DBName string - AssociationDBName string - } + ForeignKeys []JoinTableForeignKey } type JoinTableHandler struct { diff --git a/model_struct.go b/model_struct.go index cce28330..50940472 100644 --- a/model_struct.go +++ b/model_struct.go @@ -146,143 +146,150 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - for _, field := range fields { - if !field.IsIgnored { - fieldStruct := field.Struct - fieldType, indirectType := fieldStruct.Type, fieldStruct.Type - if indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { - field.IsScanner, field.IsNormal = true, true - } - - if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { - field.IsNormal = true - } - - if !field.IsNormal { - gormSettings := parseTagSetting(field.Tag.Get("gorm")) - toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) - - getForeignField := func(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == ToDBName(column) { - return field - } - } - return nil + defer func() { + for _, field := range fields { + if !field.IsIgnored { + fieldStruct := field.Struct + fieldType, indirectType := fieldStruct.Type, fieldStruct.Type + if indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() } - var relationship = &Relationship{} - - foreignKey := gormSettings["FOREIGNKEY"] - if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldName = polymorphicField.Name - relationship.ForeignDBName = polymorphicField.DBName - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - polymorphicField.IsForeignKey = true - } - } + if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { + field.IsScanner, field.IsNormal = true, true } - switch indirectType.Kind() { - case reflect.Slice: - elemType := indirectType.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if foreignKey == "" { - foreignKey = scopeType.Name() + "Id" - } - - if many2many := gormSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" - relationship.JoinTableHandler = JoinTableHandler{} - - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] - if associationForeignKey == "" { - associationForeignKey = elemType.Name() + "Id" - } - - relationship.ForeignFieldName = foreignKey - relationship.ForeignDBName = ToDBName(foreignKey) - relationship.AssociationForeignFieldName = associationForeignKey - relationship.AssociationForeignDBName = ToDBName(associationForeignKey) - field.Relationship = relationship - } else { - relationship.Kind = "has_many" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - case reflect.Struct: - if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - for _, toField := range toScope.GetStructFields() { - toField = toField.clone() - toField.Names = append([]string{fieldStruct.Name}, toField.Names...) - modelStruct.StructFields = append(modelStruct.StructFields, toField) - if toField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) - } - } - continue - } else { - belongsToForeignKey := foreignKey - if belongsToForeignKey == "" { - belongsToForeignKey = field.Name + "Id" - } - - if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { - relationship.Kind = "belongs_to" - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else { - if foreignKey == "" { - foreignKey = modelStruct.ModelType.Name() + "Id" - } - relationship.Kind = "has_one" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { - field.Relationship = relationship - } - } - } - default: + if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { field.IsNormal = true } - } - if field.IsNormal { - if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + if !field.IsNormal { + gormSettings := parseTagSetting(field.Tag.Get("gorm")) + toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) + + getForeignField := func(column string, fields []*StructField) *StructField { + for _, field := range fields { + if field.Name == column || field.DBName == ToDBName(column) { + return field + } + } + return nil + } + + var relationship = &Relationship{} + + foreignKey := gormSettings["FOREIGNKEY"] + if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { + if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { + relationship.ForeignFieldName = polymorphicField.Name + relationship.ForeignDBName = polymorphicField.DBName + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + polymorphicField.IsForeignKey = true + } + } + } + + switch indirectType.Kind() { + case reflect.Slice: + elemType := indirectType.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + + if elemType.Kind() == reflect.Struct { + if foreignKey == "" { + foreignKey = scopeType.Name() + "Id" + } + + if many2many := gormSettings["MANY2MANY"]; many2many != "" { + relationship.Kind = "many_to_many" + associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] + if associationForeignKey == "" { + associationForeignKey = elemType.Name() + "Id" + } + + relationship.ForeignFieldName = foreignKey + relationship.ForeignDBName = ToDBName(foreignKey) + relationship.AssociationForeignFieldName = associationForeignKey + relationship.AssociationForeignDBName = ToDBName(associationForeignKey) + relationship.JoinTableHandler = &JoinTableHandler{ + TableName: many2many, + Source: JoinTableSource{ModelType: scopeType}, + Destination: JoinTableSource{ModelType: elemType}, + } + updateJoinTableHandler(relationship) + + field.Relationship = relationship + } else { + relationship.Kind = "has_many" + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } + } + } else { + field.IsNormal = true + } + case reflect.Struct: + if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + for _, toField := range toScope.GetStructFields() { + toField = toField.clone() + toField.Names = append([]string{fieldStruct.Name}, toField.Names...) + modelStruct.StructFields = append(modelStruct.StructFields, toField) + if toField.IsPrimaryKey { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) + } + } + continue + } else { + belongsToForeignKey := foreignKey + if belongsToForeignKey == "" { + belongsToForeignKey = field.Name + "Id" + } + + if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { + relationship.Kind = "belongs_to" + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else { + if foreignKey == "" { + foreignKey = modelStruct.ModelType.Name() + "Id" + } + relationship.Kind = "has_one" + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } + } + } + default: + field.IsNormal = true + } + } + + if field.IsNormal { + if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } } } + modelStruct.StructFields = append(modelStruct.StructFields, field) } - modelStruct.StructFields = append(modelStruct.StructFields, field) - } + }() modelStructs[scopeType] = &modelStruct From 36efd0a561dc8b158149fae7cfa21a1998136425 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 19 Mar 2015 16:42:13 +0800 Subject: [PATCH 07/17] Fix JoinTableHandler JoinWith --- association_test.go | 2 +- join_table.go | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/association_test.go b/association_test.go index 3ffd8880..a7b8f136 100644 --- a/association_test.go +++ b/association_test.go @@ -143,7 +143,7 @@ func TestManyToMany(t *testing.T) { // Query var newLanguages []Language - DB.Model(&user).Related(&newLanguages, "Languages") + DB.Debug().Model(&user).Related(&newLanguages, "Languages") if len(newLanguages) != len([]string{"ZH", "EN"}) { t.Errorf("Query many to many relations") } diff --git a/join_table.go b/join_table.go index 9b23f89f..d29f7bfb 100644 --- a/join_table.go +++ b/join_table.go @@ -128,7 +128,8 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { var values []interface{} if s.Source.ModelType == modelType { for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName))) + destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } for _, foreignKey := range s.Source.ForeignKeys { @@ -137,7 +138,8 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { } } else if s.Destination.ModelType == modelType { for _, foreignKey := range s.Source.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName))) + sourceTableName := scope.New(reflect.New(s.Source.ModelType).Interface()).QuotedTableName() + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), sourceTableName, scope.Quote(foreignKey.AssociationDBName))) } for _, foreignKey := range s.Destination.ForeignKeys { @@ -146,6 +148,6 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { } } - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", strings.Join(joinConditions, " AND "))). + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). Where(strings.Join(queryConditions, " AND "), values...) } From 44b106c8e26ab7b0320d489cfc66cf9db0ed1dc9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 19 Mar 2015 18:23:54 +0800 Subject: [PATCH 08/17] Fix tests --- association.go | 4 +--- association_test.go | 2 +- join_table.go | 46 ++++++++++++++++++++++----------------------- join_table_test.go | 25 +++++++++++++----------- main.go | 9 +++++++++ scope_private.go | 5 +---- 6 files changed, 49 insertions(+), 42 deletions(-) diff --git a/association.go b/association.go index 60763f8c..89bb1bec 100644 --- a/association.go +++ b/association.go @@ -163,9 +163,7 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). - Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) - relationship.JoinTableHandler.JoinWith(query, association.Scope.Value).Count(&count) + relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/association_test.go b/association_test.go index a7b8f136..3ffd8880 100644 --- a/association_test.go +++ b/association_test.go @@ -143,7 +143,7 @@ func TestManyToMany(t *testing.T) { // Query var newLanguages []Language - DB.Debug().Model(&user).Related(&newLanguages, "Languages") + DB.Model(&user).Related(&newLanguages, "Languages") if len(newLanguages) != len([]string{"ZH", "EN"}) { t.Errorf("Query many to many relations") } diff --git a/join_table.go b/join_table.go index d29f7bfb..163bb4e2 100644 --- a/join_table.go +++ b/join_table.go @@ -18,28 +18,6 @@ type JoinTableForeignKey struct { AssociationDBName string } -func updateJoinTableHandler(relationship *Relationship) { - handler := relationship.JoinTableHandler.(*JoinTableHandler) - - destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} - for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - db := relationship.ForeignDBName - handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: db, - AssociationDBName: primaryField.DBName, - }) - } - - sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} - for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { - db := relationship.AssociationForeignDBName - handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ - DBName: db, - AssociationDBName: primaryField.DBName, - }) - } -} - type JoinTableSource struct { ModelType reflect.Type ForeignKeys []JoinTableForeignKey @@ -51,6 +29,28 @@ type JoinTableHandler struct { Destination JoinTableSource `sql:"-"` } +func updateJoinTableHandler(relationship *Relationship) { + handler := relationship.JoinTableHandler.(*JoinTableHandler) + + destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} + for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { + db := relationship.AssociationForeignDBName + handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } + + sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} + for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + db := relationship.ForeignDBName + handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } +} + func (s JoinTableHandler) Table(*DB) string { return s.TableName } @@ -88,7 +88,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) values = append(values, value) } - for _, value := range searchMap { + for _, value := range values { values = append(values, value) } diff --git a/join_table_test.go b/join_table_test.go index 429e46e1..38e9f943 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -15,29 +15,32 @@ type Person struct { } type PersonAddress struct { - gorm.JoinTableHandler PersonID int AddressID int DeletedAt time.Time CreatedAt time.Time } -func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { +func (*PersonAddress) Table(db *gorm.DB) string { + return "person_addresses" +} + +func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ - relationship.ForeignDBName: foreignValue, - relationship.AssociationForeignDBName: associationValue, + "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), + "address_id": db.NewScope(associationValue).PrimaryKeyValue(), }).Assign(map[string]interface{}{ - relationship.ForeignFieldName: foreignValue, - relationship.AssociationForeignFieldName: associationValue, - "DeletedAt": gorm.Expr("NULL"), + "person_id": foreignValue, + "address_id": associationValue, + "DeletedAt": gorm.Expr("NULL"), }).FirstOrCreate(&PersonAddress{}).Error } -func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error { +func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error { return db.Delete(&PersonAddress{}).Error } -func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { +func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } @@ -45,7 +48,7 @@ func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *go func TestJoinTable(t *testing.T) { DB.Exec("drop table person_addresses;") DB.AutoMigrate(&Person{}) - // DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") + DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) address1 := &Address{Address1: "address 1"} address2 := &Address{Address1: "address 2"} @@ -58,7 +61,7 @@ func TestJoinTable(t *testing.T) { t.Errorf("Should found one address") } - if DB.Model(person).Association("Addresses").Count() != 1 { + if DB.Debug().Model(person).Association("Addresses").Count() != 1 { t.Errorf("Should found one address") } diff --git a/main.go b/main.go index 377ef582..5f7db05a 100644 --- a/main.go +++ b/main.go @@ -469,3 +469,12 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { value, ok = s.values[name] return } + +func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { + for _, field := range s.NewScope(source).GetModelStruct().StructFields { + if field.Name == column || field.DBName == column { + field.Relationship.JoinTableHandler = handler + s.Table(handler.Table(s)).AutoMigrate(handler) + } + } +} diff --git a/scope_private.go b/scope_private.go index da78d3f2..d1f4a10b 100644 --- a/scope_private.go +++ b/scope_private.go @@ -403,10 +403,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db)) - scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value). - Where(fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)), scope.PrimaryKeyValue()). - Find(value).Error) + scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() From 9af056349670e100e76a572f31adb681205fd8b6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 19 Mar 2015 18:30:35 +0800 Subject: [PATCH 09/17] Passed all tests --- join_table_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/join_table_test.go b/join_table_test.go index 38e9f943..40f36799 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -42,7 +42,7 @@ func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error { func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) - return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) + return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { @@ -61,7 +61,7 @@ func TestJoinTable(t *testing.T) { t.Errorf("Should found one address") } - if DB.Debug().Model(person).Association("Addresses").Count() != 1 { + if DB.Model(person).Association("Addresses").Count() != 1 { t.Errorf("Should found one address") } From 94a5ebe5b4011fd4f2064472b2e1f08a0d85646c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 20 Mar 2015 11:11:30 +0800 Subject: [PATCH 10/17] Refactor JoinTableHandler --- join_table.go => join_table_handler.go | 44 ++++++++++++-------------- join_table_test.go | 7 ++-- main.go | 9 ++++-- model_struct.go | 9 ++---- 4 files changed, 32 insertions(+), 37 deletions(-) rename join_table.go => join_table_handler.go (74%) diff --git a/join_table.go b/join_table_handler.go similarity index 74% rename from join_table.go rename to join_table_handler.go index 163bb4e2..21e88fe1 100644 --- a/join_table.go +++ b/join_table_handler.go @@ -1,14 +1,16 @@ package gorm import ( + "errors" "fmt" "reflect" "strings" ) type JoinTableHandlerInterface interface { + Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) Table(db *DB) string - Add(db *DB, source1 interface{}, source2 interface{}) error + Add(db *DB, source interface{}, destination interface{}) error Delete(db *DB, sources ...interface{}) error JoinWith(db *DB, source interface{}) *DB } @@ -29,22 +31,24 @@ type JoinTableHandler struct { Destination JoinTableSource `sql:"-"` } -func updateJoinTableHandler(relationship *Relationship) { - handler := relationship.JoinTableHandler.(*JoinTableHandler) +func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { + s.TableName = tableName - destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} - for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - db := relationship.AssociationForeignDBName - handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ + s.Source = JoinTableSource{ModelType: source} + sourceScope := &Scope{Value: reflect.New(source).Interface()} + for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + db := relationship.ForeignDBName + s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ DBName: db, AssociationDBName: primaryField.DBName, }) } - sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} - for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { - db := relationship.ForeignDBName - handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ + s.Destination = JoinTableSource{ModelType: destination} + destinationScope := &Scope{Value: reflect.New(destination).Interface()} + for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { + db := relationship.AssociationForeignDBName + s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ DBName: db, AssociationDBName: primaryField.DBName, }) @@ -136,18 +140,10 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - sourceTableName := scope.New(reflect.New(s.Source.ModelType).Interface()).QuotedTableName() - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), sourceTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - for _, foreignKey := range s.Destination.ForeignKeys { - queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) - values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) - } + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) + } else { + db.Error = errors.New("wrong source type for join table handler") + return db } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). - Where(strings.Join(queryConditions, " AND "), values...) } diff --git a/join_table_test.go b/join_table_test.go index 40f36799..f8b097b6 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -15,16 +15,13 @@ type Person struct { } type PersonAddress struct { + gorm.JoinTableHandler PersonID int AddressID int DeletedAt time.Time CreatedAt time.Time } -func (*PersonAddress) Table(db *gorm.DB) string { - return "person_addresses" -} - func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), @@ -32,7 +29,7 @@ func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValu }).Assign(map[string]interface{}{ "person_id": foreignValue, "address_id": associationValue, - "DeletedAt": gorm.Expr("NULL"), + "deleted_at": gorm.Expr("NULL"), }).FirstOrCreate(&PersonAddress{}).Error } diff --git a/main.go b/main.go index 5f7db05a..b66ceda3 100644 --- a/main.go +++ b/main.go @@ -473,8 +473,13 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { for _, field := range s.NewScope(source).GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - field.Relationship.JoinTableHandler = handler - s.Table(handler.Table(s)).AutoMigrate(handler) + if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { + source := (&Scope{Value: source}).GetModelStruct().ModelType + destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType + handler.Setup(field.Relationship, many2many, source, destination) + field.Relationship.JoinTableHandler = handler + s.Table(handler.Table(s)).AutoMigrate(handler) + } } } } diff --git a/model_struct.go b/model_struct.go index 50940472..b7c44414 100644 --- a/model_struct.go +++ b/model_struct.go @@ -215,13 +215,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.ForeignDBName = ToDBName(foreignKey) relationship.AssociationForeignFieldName = associationForeignKey relationship.AssociationForeignDBName = ToDBName(associationForeignKey) - relationship.JoinTableHandler = &JoinTableHandler{ - TableName: many2many, - Source: JoinTableSource{ModelType: scopeType}, - Destination: JoinTableSource{ModelType: elemType}, - } - updateJoinTableHandler(relationship) + joinTableHandler := JoinTableHandler{} + joinTableHandler.Setup(relationship, many2many, scopeType, elemType) + relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { relationship.Kind = "has_many" From 4b98b145b11d3da5a7562d89ca070134bd49b2b3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 20 Mar 2015 11:49:16 +0800 Subject: [PATCH 11/17] Fix foreign db name in join table for multi primary keys relations --- join_table_handler.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/join_table_handler.go b/join_table_handler.go index 21e88fe1..b4299f5a 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -37,9 +37,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.Source = JoinTableSource{ModelType: source} sourceScope := &Scope{Value: reflect.New(source).Interface()} for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { - db := relationship.ForeignDBName + if relationship.ForeignDBName == "" { + relationship.ForeignFieldName = source.Name() + primaryField.Name + relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) + } s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: db, + DBName: relationship.ForeignDBName, AssociationDBName: primaryField.DBName, }) } @@ -47,9 +50,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.Destination = JoinTableSource{ModelType: destination} destinationScope := &Scope{Value: reflect.New(destination).Interface()} for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - db := relationship.AssociationForeignDBName + if relationship.AssociationForeignDBName == "" { + relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name + relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName) + } s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: db, + DBName: relationship.AssociationForeignDBName, AssociationDBName: primaryField.DBName, }) } From 1e28551d25f09dafecd7f6561f6bfa6651352085 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 20 Mar 2015 17:21:13 +0800 Subject: [PATCH 12/17] Fix additional SQL type --- model_struct.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_struct.go b/model_struct.go index b7c44414..f73c902b 100644 --- a/model_struct.go +++ b/model_struct.go @@ -312,7 +312,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string { additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] if value, ok := sqlSettings["DEFAULT"]; ok { - additionalType = additionalType + "DEFAULT " + value + additionalType = additionalType + " DEFAULT " + value } if field.IsScanner { From 7d16055a5d33f44df201c310764fedeb0bd6cde4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 23 Mar 2015 11:07:39 +0800 Subject: [PATCH 13/17] Don't use instance setting for order_by_primary_key --- callback_query.go | 2 +- main.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/callback_query.go b/callback_query.go index 4538b272..5daa5fec 100644 --- a/callback_query.go +++ b/callback_query.go @@ -21,7 +21,7 @@ func Query(scope *Scope) { dest = reflect.Indirect(reflect.ValueOf(value)) } - if orderBy, ok := scope.InstanceGet("gorm:order_by_primary_key"); ok { + if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { if primaryKey := scope.PrimaryKey(); primaryKey != "" { scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) } diff --git a/main.go b/main.go index b66ceda3..e197a99c 100644 --- a/main.go +++ b/main.go @@ -193,14 +193,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) - return newScope.InstanceSet("gorm:order_by_primary_key", "ASC"). + return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) - return newScope.InstanceSet("gorm:order_by_primary_key", "DESC"). + return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } From 8389d92f783142e7b15c3b2a32614e16a105d74b Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Mon, 23 Mar 2015 15:11:41 -0700 Subject: [PATCH 14/17] Futher clarified error messaging for invalid `plucks'. --- scope_private.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scope_private.go b/scope_private.go index d1f4a10b..8de12ced 100644 --- a/scope_private.go +++ b/scope_private.go @@ -3,7 +3,6 @@ package gorm import ( "database/sql" "database/sql/driver" - "errors" "fmt" "reflect" "regexp" @@ -360,7 +359,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) scope.Search.Select(column) if dest.Kind() != reflect.Slice { - scope.Err(errors.New("results should be a slice")) + scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) return scope } From dcc06e22f7b54f4c38e347b25ba827845f92ea08 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Tue, 24 Mar 2015 10:33:51 -0700 Subject: [PATCH 15/17] FoundationDB dialect layer and compatibility updates. --- README.md | 1 + dialect.go | 2 ++ foundation.go | 78 +++++++++++++++++++++++++++++++++++++++++++ join_table_handler.go | 2 +- main.go | 3 ++ main_test.go | 31 ++++++++++++----- test_all.sh | 2 +- 7 files changed, 109 insertions(+), 10 deletions(-) create mode 100644 foundation.go diff --git a/README.md b/README.md index a6dd1865..50c4c5fe 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,7 @@ import ( ) db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") diff --git a/dialect.go b/dialect.go index 2e64cca5..f3221075 100644 --- a/dialect.go +++ b/dialect.go @@ -24,6 +24,8 @@ func NewDialect(driver string) Dialect { switch driver { case "postgres": d = &postgres{} + case "foundation": + d = &foundation{} case "mysql": d = &mysql{} case "sqlite3": diff --git a/foundation.go b/foundation.go new file mode 100644 index 00000000..a9c8f500 --- /dev/null +++ b/foundation.go @@ -0,0 +1,78 @@ +package gorm + +import ( + "fmt" + "reflect" + "time" +) + +type foundation struct { + commonDialect +} + +func (foundation) BinVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (foundation) SupportLastInsertId() bool { + return false +} + +func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string { + switch value.Kind() { + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "serial" + } + return "int" + case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "bigserial" + } + return "bigint" + case reflect.Float32, reflect.Float64: + return "double" + case reflect.String: + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } + return "clob" + case reflect.Struct: + if _, ok := value.Interface().(time.Time); ok { + return "datetime" + } + default: + if _, ok := value.Interface().([]byte); ok { + return "blob" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String())) +} + +func (f foundation) ReturningStr(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key) +} + +func (foundation) HasTable(scope *Scope, tableName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count) + return count > 0 +} + +func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) + return count > 0 +} + +func (f foundation) RemoveIndex(scope *Scope, indexName string) { + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName))) +} + +func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) + return count > 0 +} diff --git a/join_table_handler.go b/join_table_handler.go index b4299f5a..9f705564 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -104,7 +104,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) quotedTable := s.Table(db) sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);", + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", quotedTable, strings.Join(assignColumns, ","), strings.Join(binVars, ","), diff --git a/main.go b/main.go index e197a99c..82567971 100644 --- a/main.go +++ b/main.go @@ -55,6 +55,9 @@ func Open(dialect string, args ...interface{}) (DB, error) { driver = value source = args[1].(string) } + if driver == "foundation" { + driver = "postgres" // FoundationDB speaks a postgres-compatible protocol. + } dbSql, err = sql.Open(driver, source) case sqlCommon: source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() diff --git a/main_test.go b/main_test.go index 3864fcbd..b547534c 100644 --- a/main_test.go +++ b/main_test.go @@ -36,6 +36,9 @@ func init() { case "postgres": fmt.Println("testing postgres...") DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") + case "foundation": + fmt.Println("testing foundation...") + DB, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") case "mssql": fmt.Println("testing mssql...") DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") @@ -445,6 +448,14 @@ func TestHaving(t *testing.T) { } } +func DialectHasTzSupport() bool { + // NB: mssql and FoundationDB do not support time zones. + if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { + return false + } + return true +} + func TestTimeWithZone(t *testing.T) { var format = "2006-01-02 15:04:05 -0700" var times []time.Time @@ -456,26 +467,30 @@ func TestTimeWithZone(t *testing.T) { name := "time_with_zone_" + strconv.Itoa(index) user := User{Name: name, Birthday: vtime} - // TODO mssql does not support time zones - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { + if !DialectHasTzSupport() { + // If our driver dialect doesn't support TZ's, just use UTC for everything here. user.Birthday = vtime.UTC() } + DB.Save(&user) - if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { - t.Errorf("User's birthday should not be changed after save") + expectedBirthday := "2013-02-18 17:51:49 +0000" + foundBirthday := user.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) } var findUser, findUser2, findUser3 User DB.First(&findUser, "name = ?", name) - if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { - t.Errorf("User's birthday should not be changed after find") + foundBirthday = findUser.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday) } - if DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() { + if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { t.Errorf("User should be found") } - if !DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() { + if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { t.Errorf("User should not be found") } } diff --git a/test_all.sh b/test_all.sh index 6c5593b3..bd28294d 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "foundation" "mysql" "sqlite") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test From ab48cd222a3b4edf9a1b9db25f5d331dd9ff5a80 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Sat, 28 Mar 2015 14:15:12 -0700 Subject: [PATCH 16/17] `UpdateColumns(...)` no longer triggers save of associated records. --- README.md | 2 +- callback_shared.go | 6 ++++++ main.go | 1 + scope.go | 8 ++++++++ update_test.go | 29 +++++++++++++++++++++++++++++ 5 files changed, 45 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a6dd1865..c855c591 100644 --- a/README.md +++ b/README.md @@ -360,7 +360,7 @@ db.Model(&user).Updates(User{Name: "hello", Age: 18}) ### Update Without Callbacks -By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks: +By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations: ```go db.Model(&user).UpdateColumn("name", "hello") diff --git a/callback_shared.go b/callback_shared.go index ae75c250..88158cfc 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -11,6 +11,9 @@ func CommitOrRollbackTransaction(scope *Scope) { } func SaveBeforeAssociations(scope *Scope) { + if !scope.shouldSaveAssociations() { + return + } for _, field := range scope.Fields() { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { @@ -25,6 +28,9 @@ func SaveBeforeAssociations(scope *Scope) { } func SaveAfterAssociations(scope *Scope) { + if !scope.shouldSaveAssociations() { + return + } for _, field := range scope.Fields() { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && diff --git a/main.go b/main.go index e197a99c..4a029512 100644 --- a/main.go +++ b/main.go @@ -266,6 +266,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { func (s *DB) UpdateColumns(values interface{}) *DB { return s.clone().NewScope(s.Value). Set("gorm:update_column", true). + Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callback.updates).db } diff --git a/scope.go b/scope.go index fccc5b88..d8e39348 100644 --- a/scope.go +++ b/scope.go @@ -398,3 +398,11 @@ func (scope *Scope) changeableField(field *Field) bool { return !field.IsIgnored } + +func (scope *Scope) shouldSaveAssociations() bool { + saveAssociations, ok := scope.Get("gorm:save_associations") + if ok && !saveAssociations.(bool) { + return false + } + return true +} diff --git a/update_test.go b/update_test.go index 8a019087..9a0af806 100644 --- a/update_test.go +++ b/update_test.go @@ -382,3 +382,32 @@ func TestOmitWithUpdateColumn(t *testing.T) { t.Errorf("Should omit name column when update user") } } + +func TestUpdateColumnsSkipsAssociations(t *testing.T) { + user := getPreparedUser("update_columns_user", "special_role") + user.Age = 99 + address1 := "first street" + user.BillingAddress = Address{Address1: address1} + DB.Save(user) + + // Update a single field of the user and verify that the changed address is not stored. + newAge := int64(100) + user.BillingAddress.Address1 = "second street" + db := DB.Model(user).UpdateColumns(User{Age: newAge}) + if db.RowsAffected != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) + } + + // Verify that Age now=`newAge`. + freshUser := &User{Id: user.Id} + DB.First(freshUser) + if freshUser.Age != newAge { + t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) + } + + // Verify that user's BillingAddress.Address1 is not changed and is still "first street". + DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) + if freshUser.BillingAddress.Address1 != address1 { + t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) + } +} From 3bd7cab8d67c272471d9ead0b7b6bc32f86ec98a Mon Sep 17 00:00:00 2001 From: Tim Kluge Date: Sun, 29 Mar 2015 20:47:35 +0200 Subject: [PATCH 17/17] Fix type of OwnerType in polymorphism example --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a6dd1865..58d5b879 100644 --- a/README.md +++ b/README.md @@ -546,7 +546,7 @@ Supports polymorphic has-many and has-one associations. Id int Name string OwnerId int - OwnerType int + OwnerType string } ``` Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.