diff --git a/api.go b/api.go new file mode 100644 index 00000000..6bd9cb5d --- /dev/null +++ b/api.go @@ -0,0 +1,304 @@ +package gorm + +import ( + "github.com/jinzhu/gorm/builder" + "github.com/jinzhu/gorm/dialects" +) + +// Where add condition +func (s *DB) Where(query interface{}, args ...interface{}) *DB { + tx := s.init() + tx.Statement.AddConditions(tx.Statement.BuildCondition(query, args...)) + return tx +} + +// Not add NOT condition +func (s *DB) Not(query interface{}, args ...interface{}) *DB { + tx := s.init() + tx.Statement.AddConditions(builder.Not(tx.Statement.BuildCondition(query, args...))) + return tx +} + +// And add AND conditions +func (s *DB) And(conds ...builder.Condition) *DB { + tx := s.init() + tx.Statement.AddConditions(builder.And(conds)) + return tx +} + +// Or add OR conditions +func (s *DB) Or(conds ...builder.Condition) *DB { + tx := s.init() + tx.Statement.AddConditions(builder.Or(conds)) + return tx +} + +// Joins specify Joins conditions +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +func (s *DB) Joins(query string, args ...interface{}) *DB { + tx := s.init() + // FIXME + tx.Statement.Joins = append(tx.Statement.Joins, builder.Join{Conditions: []builder.Condition{tx.Statement.BuildCondition(query, args...)}}) + return tx +} + +// Group specify the group method on the find +func (s *DB) Group(column string) *DB { + tx := s.init() + tx.Statement.GroupBy.GroupByColumns = append(tx.Statement.GroupBy.GroupByColumns, column) + return tx +} + +// Having specify HAVING conditions for GROUP BY +func (s *DB) Having(query interface{}, args ...interface{}) *DB { + tx := s.init() + tx.Statement.GroupBy.Having = append(tx.Statement.GroupBy.Having, tx.Statement.BuildCondition(query, args...)) + return tx +} + +// Order specify order when retrieve records from database +// db.Order("name DESC") +// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +func (s *DB) Order(value interface{}) *DB { + tx := s.init() + tx.Statement.OrderBy = append(tx.Statement.OrderBy, value) + return tx +} + +// Reorder works like Order, but will overwrite current order information +func (s *DB) Reorder(value interface{}) *DB { + tx := s.init() + tx.Statement.OrderBy = []builder.OrderCondition{value} + return tx +} + +// Limit specify the number of records to be retrieved +func (s *DB) Limit(limit int64) *DB { + tx := s.init() + if limit < 0 { + tx.Statement.Limit.Limit = nil + } else { + tx.Statement.Limit.Limit = &limit + } + return tx +} + +// Offset specify the number of records to skip before starting to return the records +func (s *DB) Offset(offset int64) *DB { + tx := s.init() + if offset < 0 { + tx.Statement.Limit.Offset = nil + } else { + tx.Statement.Limit.Offset = &offset + } + return tx +} + +// Select specify fields that you want when querying, creating, updating +func (s *DB) Select(query interface{}, args ...interface{}) *DB { + tx := s.init() + + switch value := query.(type) { + case string: + tx.Statement.Select.Columns = []string{value} + case []string: + tx.Statement.Select.Columns = value + default: + tx.AddError(ErrUnsupportedSelect) + } + tx.Statement.Select.Args = args + return tx +} + +// Omit specify fields that you want to ignore when creating, updating and querying +func (s *DB) Omit(columns ...string) *DB { + tx := s.init() + tx.Statement.Omit = columns + return tx +} + +// First find first record that match given conditions, order by primary key +func (s *DB) First(out interface{}, where ...interface{}) *DB { + conds := []interface{}{builder.Limit{Limit: &one}, builder.Settings{"gorm:order_by_primary_key": "ASC"}} + if len(where) > 0 { + conds = append(conds, s.Statement.BuildCondition(where[0], where[1:]...)) + } + return s.Find(out, conds...) +} + +// Take return a record that match given conditions, the order will depend on the database implementation +func (s *DB) Take(out interface{}, where ...interface{}) *DB { + conds := []interface{}{builder.Limit{Limit: &one}} + if len(where) > 0 { + conds = append(conds, s.Statement.BuildCondition(where[0], where[1:]...)) + } + return s.Find(out, conds...) +} + +// Last find last record that match given conditions, order by primary key +func (s *DB) Last(out interface{}, where ...interface{}) *DB { + conds := []interface{}{builder.Limit{Limit: &one}, builder.Settings{"gorm:order_by_primary_key": "DESC"}} + if len(where) > 0 { + conds = append(conds, s.Statement.BuildCondition(where[0], where[1:]...)) + } + return s.Find(out, conds...) +} + +// Find find records that match given conditions +func (s *DB) Find(out interface{}, where ...interface{}) *DB { + tx := s.init() + stmt := tx.Statement + stmt.Dest = out + + if len(where) > 0 { + stmt = s.Statement.Clone() + stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...)) + } + + tx.AddError(tx.Dialect().Query(stmt)) + return tx +} + +// Scan scan value to a struct +func (s *DB) Scan(dest interface{}) *DB { + var ( + tx = s.init() + stmt = tx.Statement.Clone() + ) + + stmt.Table = stmt.Dest + stmt.Dest = dest + tx.AddError(tx.Dialect().Query(stmt)) + return tx +} + +// Create insert the value into database +func (s *DB) Create(value interface{}) *DB { + tx := s.init() + tx.Statement.Dest = value + tx.AddError(tx.Dialect().Insert(tx.Statement)) + return tx +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (s *DB) Save(value interface{}) *DB { + tx := s.init() + tx.Statement.Dest = value + // FIXME check primary key has value or not + tx.AddError(tx.Dialect().Update(tx.Statement)) + return tx +} + +// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +func (s *DB) Update(column string, value interface{}) *DB { + tx := s.init() + tx.Statement.Assignments = append(tx.Statement.Assignments, builder.Assignment{Column: column, Value: value}) + tx.AddError(tx.Dialect().Update(tx.Statement)) + return tx +} + +// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +func (s *DB) Updates(values interface{}) *DB { + tx := s.init() + tx.Statement.Assignments = append(tx.Statement.Assignments, builder.Assignment{Value: values}) + tx.AddError(tx.Dialect().Update(tx.Statement)) + return tx +} + +// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +func (s *DB) Delete(value interface{}, where ...interface{}) *DB { + tx := s.init() + stmt := tx.Statement + stmt.Dest = value + + if len(where) > 0 { + stmt = s.Statement.Clone() + stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...)) + } + + tx.AddError(tx.Dialect().Update(stmt)) + return tx +} + +// Model specify the model you would like to run db operations +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") +func (s *DB) Model(value interface{}) *DB { + tx := s.init() + tx.Statement.Dest = value + return tx +} + +// Table specify the table you would like to run db operations +func (s *DB) Table(name string) *DB { + tx := s.init() + tx.Statement.Table = name + return tx +} + +// AddError add error to the db +func (s *DB) AddError(err error) { + if err != nil { + if err != ErrRecordNotFound { + s.Config.Logger.Error(err) + } + + if errs := s.GetErrors(); len(errs) == 0 { + s.Error = err + } else { + s.Error = Errors(errs).Add(err) + } + } +} + +// GetErrors get happened errors from the db +func (s *DB) GetErrors() []error { + if errs, ok := s.Error.(Errors); ok { + return errs + } else if s.Error != nil { + return []error{s.Error} + } + return []error{} +} + +// Dialect return DB dialect +func (s *DB) Dialect() dialects.Dialect { + if s.TxDialect != nil { + return s.TxDialect + } + return s.Config.Dialect +} + +// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// Refer https://jinzhu.github.io/gorm/crud.html#scopes +func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { + for _, f := range funcs { + s = f(s) + } + return s +} + +// init init DB +func (s *DB) init() *DB { + if s.Statement == nil { + return &DB{ + TxDialect: s.TxDialect, + Statement: &builder.Statement{}, + Config: s.Config, + } + } + return s +} diff --git a/association.go b/association.go deleted file mode 100644 index 8c6d9864..00000000 --- a/association.go +++ /dev/null @@ -1,375 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Association Mode contains some helper methods to handle relationship things easily. -type Association struct { - Error error - scope *Scope - column string - field *Field -} - -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} - -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames, associationForeignDBNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for idx, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) - } - } - } else { - // If has one/many relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, field.DBName) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - return association -} - -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break - } - } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break - } - } - } - } - - return association -} - -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { - return association.Replace() -} - -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - if relationship.Kind == "many_to_many" { - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } else if relationship.Kind == "belongs_to" { - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - relationship.PolymorphicValue, - ) - } - - if err := query.Model(fieldValue).Count(&count).Error; err != nil { - association.Error = err - } - return count -} - -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { - var ( - scope = association.scope - field = association.field - relationship = field.Relationship - ) - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} diff --git a/builder/statement.go b/builder/statement.go index cfcd4a0b..cadb6dff 100644 --- a/builder/statement.go +++ b/builder/statement.go @@ -1,21 +1,25 @@ package builder +import "sync" + // Column column type type Column = string // Statement GORM statement type Statement struct { - Dest interface{} // Insert, Select, Update, Delete - Table string // Insert, Select, Update, Delete - Select []interface{} // Select - Joins []Join // Select - GroupBy GroupBy // Select - OrderBy OrderBy // Select - Preload []Column // Select - Limit Limit // Select, Update - Where []Condition // Select, Update, Delete - Assignments []Assignment // Insert, Update - Returnnings []Column // Insert, Update, Delete + Dest interface{} // Insert, Select, Update, Delete + Table interface{} // Insert, Select, Update, Delete + Select SelectColumn // Insert, Select, Update + Omit []Column // Insert, Select, Update + Joins []Join // Select + GroupBy GroupBy // Select + OrderBy OrderBy // Select + Preload []Column // Select + Limit Limit // Select, Update + Conditions []Condition // Select, Update, Delete + Assignments []Assignment // Insert, Update + Returnnings []Column // Insert, Update, Delete + Settings sync.Map } // Condition query condition statement interface @@ -23,6 +27,18 @@ type Condition interface { // ToSQL() } +// Settings settings +type Settings map[string]interface{} + +// DefaultValue default value type +type DefaultValue string + +// SelectColumn select columns +type SelectColumn struct { + Columns []string + Args []interface{} +} + // Join join statement type Join struct { Table string @@ -37,8 +53,11 @@ type GroupBy struct { Having []Condition } +// OrderCondition order condition, could be string or sql expr +type OrderCondition interface{} + // OrderBy order by statement -type OrderBy []OrderByColumn +type OrderBy []OrderCondition // OrderByColumn column used for order type OrderByColumn struct { @@ -48,8 +67,8 @@ type OrderByColumn struct { // Limit limit statement type Limit struct { - Limit int - Offset int + Limit *int64 + Offset *int64 } // Assignment assign statement @@ -57,3 +76,20 @@ type Assignment struct { Column Column Value interface{} } + +// Clone clone current statement +func (stmt *Statement) Clone() *Statement { + newStatement := *stmt + // FIXME fix settings + return &newStatement +} + +// BuildCondition build condition +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) Condition { + return nil +} + +// AddConditions add conditions +func (stmt *Statement) AddConditions(conds ...Condition) { + stmt.Conditions = append(stmt.Conditions, conds...) +} diff --git a/dialect.go b/dialect.go deleted file mode 100644 index 6432cadd..00000000 --- a/dialect.go +++ /dev/null @@ -1,20 +0,0 @@ -package dialects - -import ( - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/builder" -) - -// Dialect GORM dialect interface -type Dialect interface { - // CRUD operations - Insert(*gorm.DB, builder.Statement) error - Query(*gorm.DB, builder.Statement) error - Update(*gorm.DB, builder.Statement) error - Delete(*gorm.DB, builder.Statement) error - - // DB Driver interface - QueryRow(*gorm.DB) error - QueryRows(*gorm.DB) error - Exec(*gorm.DB) error -} diff --git a/dialects/common/utils/statement.go b/dialects/common/utils/statement.go new file mode 100644 index 00000000..e6dc04c4 --- /dev/null +++ b/dialects/common/utils/statement.go @@ -0,0 +1,17 @@ +package utils + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/builder" + "github.com/jinzhu/gorm/model" +) + +// GetCreatingAssignments get creating assignments +func GetCreatingAssignments(stmt *builder.Statement, errs *gorm.Errors) chan []model.Field { + return nil +} + +// GetTable get table name +func GetTable(stmt *builder.Statement, errs *gorm.Errors) chan string { + return nil +} diff --git a/dialects/common/utils/utils.go b/dialects/common/utils/utils.go new file mode 100644 index 00000000..c38d23c1 --- /dev/null +++ b/dialects/common/utils/utils.go @@ -0,0 +1,19 @@ +package utils + +// ToSearchableMap convert attrs to searchable map +func ToSearchableMap(attrs ...interface{}) (result interface{}) { + if len(attrs) > 1 { + if str, ok := attrs[0].(string); ok { + result = map[string]interface{}{str: attrs[1]} + } + } else if len(attrs) == 1 { + if attr, ok := attrs[0].(map[string]interface{}); ok { + result = attr + } + + if attr, ok := attrs[0].(interface{}); ok { + result = attr + } + } + return +} diff --git a/dialects/dialect.go b/dialects/dialect.go new file mode 100644 index 00000000..fe1e1d28 --- /dev/null +++ b/dialects/dialect.go @@ -0,0 +1,13 @@ +package dialects + +import ( + "github.com/jinzhu/gorm/builder" +) + +// Dialect GORM dialect interface +type Dialect interface { + Insert(*builder.Statement) error + Query(*builder.Statement) error + Update(*builder.Statement) error + Delete(*builder.Statement) error +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 069ad3a9..5c519d79 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -1,3 +1,55 @@ package sqlite -import _ "github.com/mattn/go-sqlite3" +import ( + "bytes" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/builder" + "github.com/jinzhu/gorm/dialects/common/utils" + + // import sqlite3 driver + _ "github.com/mattn/go-sqlite3" +) + +// Dialect Sqlite3 Dialect for GORM +type Dialect struct { +} + +// Insert insert +func (*Dialect) Insert(stmt *builder.Statement) error { + var ( + args []interface{} + defaultValueColumns []string + errs = gorm.Errors{} + assignmentsChan = utils.GetCreatingAssignments(stmt, &errs) + tableNameChan = utils.GetTable(stmt, &errs) + ) + + s := bytes.NewBufferString("INSERT INTO ") + s.WriteString(<-tableNameChan) + + if assignments := <-assignmentsChan; len(assignments) > 0 { + for column, value := range assignments { + args = append(args, value...) + } + } else { + // assign default value + } + + return nil +} + +// Query query +func (*Dialect) Query(*builder.Statement) error { + return nil +} + +// Update update +func (*Dialect) Update(*builder.Statement) error { + return nil +} + +// Delete delete +func (*Dialect) Delete(*builder.Statement) error { + return nil +} diff --git a/errors.go b/errors.go index da2cf13c..6a1ddd06 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,8 @@ var ( ErrCantStartTransaction = errors.New("can't start transaction") // ErrUnaddressable unaddressable value ErrUnaddressable = errors.New("using unaddressable value") + // ErrUnsupportedSelect unsupported select value + ErrUnsupportedSelect = errors.New("using unsupported select") ) // Errors contains all happened errors diff --git a/gorm.go b/gorm.go index b4e1ed88..e1b3fa3b 100644 --- a/gorm.go +++ b/gorm.go @@ -3,16 +3,15 @@ package gorm import ( "time" + "github.com/jinzhu/gorm/builder" + "github.com/jinzhu/gorm/dialects" "github.com/jinzhu/gorm/logger" ) +var one int64 = 1 + // Config GORM config type Config struct { - // MaxIdleConnections sets the maximum number of connections in the idle connection pool - MaxIdleConnections int - // MaxOpenConnections sets the maximum number of open connections to the database - MaxOpenConnections int - // SingularTable use singular table name, by default, GORM will pluralize your struct's name as table name // Refer https://github.com/jinzhu/inflection for inflection rules SingularTable bool @@ -20,37 +19,28 @@ type Config struct { // BlockGlobalUpdate generates an error on update/delete without where clause, this is to prevent eventual error with empty objects updates/deletions BlockGlobalUpdate bool - // Dialect DB Dialect - Dialect Dialect - - // Callbacks defined GORM callbacks - Callbacks *Callback - // Logger Logger logger.Interface LogMode logger.LogLevel - // db fresh db connection - globalDB SQLCommon + // Dialect DB Dialect + Dialect dialects.Dialect } -// DB contains information for current db connection +// DB GORM DB definition type DB struct { - // current instance - Value interface{} - tx SQLCommon - search *search - values map[string]interface{} + TxDialect dialects.Dialect + Statement *builder.Statement // Global config - config *Config + Config *Config // Result result fields Error error RowsAffected int64 } -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models +// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which can be embedded in your model // type User struct { // gorm.Model // } diff --git a/logger/logger.go b/logger/logger.go index 52306253..113c773b 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -25,7 +25,7 @@ var DefaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} const ( // Info print SQL, warn messages and errors - Info LogLevel = iota - 1 + Info LogLevel = iota + 1 // Warn print warn messages and errors Warn // Error print errors diff --git a/main.go b/main.go deleted file mode 100644 index 0f853d1c..00000000 --- a/main.go +++ /dev/null @@ -1,712 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" -) - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (db *DB, err error) { - if len(args) == 0 { - err = errors.New("invalid database source") - return nil, err - } - var source string - var dbSQL SQLCommon - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - case SQLCommon: - dbSQL = value - } - - db = &DB{ - db: dbSQL, - values: map[string]interface{}{}, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), - } - db.parent = db - if err != nil { - return - } - // Send a ping to make sure the database connection is alive. - if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { - d.Close() - } - } - return -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -type closer interface { - Close() error -} - -// Close close current db connection. If database connection is not an io.Closer, returns an error. -func (s *DB) Close() error { - if db, ok := s.parent.db.(closer); ok { - return db.Close() - } - return errors.New("can't close current db") -} - -// DB get `*sql.DB` from current connection -// If the underlying database connection is not a *sql.DB, returns nil -func (s *DB) DB() *sql.DB { - db, _ := s.db.(*sql.DB) - return db -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() SQLCommon { - return s.db -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.parent.dialect -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() - return s.parent.callbacks -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} -} - -// QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(scope.SQL, scope.SQLVars...) -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query interface{}, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Take return a record that match given conditions, the order will depend on the database implementation -func (s *DB) Take(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - scope = s.NewScope(result) - clone = scope.db - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := s.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db - } else if len(c.search.assignAttrs) > 0 { - return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.NewScope(value) - if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db - if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().FirstOrCreate(value) - } - return newDB - } - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.NewScope(nil) - generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - // FIXME Debug Mode - return s -} - -// Begin begin a transaction -func (s *DB) Begin() *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() - c.db = interface{}(tx).(SQLCommon) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { - s.AddError(db.Rollback()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) - return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// RemoveForeignKey Remove foreign key from the given scope, e.g: -// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") -func (s *DB) RemoveForeignKey(field string, dest string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeForeignKey(field, dest) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == 0 { - go s.print(fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors(s.GetErrors()) - errors = errors.Add(err) - if len(errors) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() []error { - if errs, ok := s.Error.(Errors); ok { - return errs - } else if s.Error != nil { - return []error{s.Error} - } - return []error{} -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := &DB{ - db: s.db, - parent: s.parent, - values: map[string]interface{}{}, - Value: s.Value, - Error: s.Error, - blockGlobalUpdate: s.blockGlobalUpdate, - } - - for key, value := range s.values { - db.values[key] = value - } - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = db - return db -} - -func (s *DB) print(v ...interface{}) { - s.logger.Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) - } -} diff --git a/model/field.go b/model/field.go index 11c410b0..cbf31c06 100644 --- a/model/field.go +++ b/model/field.go @@ -1,4 +1,4 @@ -package gorm +package model import ( "database/sql" diff --git a/dialects/common/join_table_handler.go b/model/join_table_handler.go similarity index 99% rename from dialects/common/join_table_handler.go rename to model/join_table_handler.go index a036d46d..3150cecd 100644 --- a/dialects/common/join_table_handler.go +++ b/model/join_table_handler.go @@ -1,4 +1,4 @@ -package gorm +package model import ( "errors" diff --git a/model/model_struct.go b/model/model_struct.go index f571e2e8..00e95326 100644 --- a/model/model_struct.go +++ b/model/model_struct.go @@ -1,4 +1,4 @@ -package gorm +package model import ( "database/sql" @@ -9,11 +9,12 @@ import ( "sync" "time" + "github.com/jinzhu/gorm/builder" "github.com/jinzhu/inflection" ) // DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { +var DefaultTableNameHandler = func(db *builder.Statement, defaultTableName string) string { return defaultTableName } diff --git a/utils.go b/utils.go deleted file mode 100644 index dfaae939..00000000 --- a/utils.go +++ /dev/null @@ -1,285 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - if i > 0 { - if currCase == upper { - if lastCase == upper && nextCase == upper { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && nextCase == upper { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - -// SQL expression -type expr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - switch value.Kind() { - case reflect.String: - return value.Len() == 0 - case reflect.Bool: - return !value.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return value.Uint() == 0 - case reflect.Float32, reflect.Float64: - return value.Float() == 0 - case reflect.Interface, reflect.Ptr: - return value.IsNil() - } - - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -}