Add Update, Updates back
This commit is contained in:
parent
ea67d1d377
commit
cc03f438ef
112
do.go
112
do.go
@ -13,13 +13,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Do struct {
|
type Do struct {
|
||||||
chain *Chain
|
|
||||||
db *DB
|
db *DB
|
||||||
guessedTableName string
|
model *Model
|
||||||
specifiedTableName string
|
tableName string
|
||||||
startedTransaction bool
|
startedTransaction bool
|
||||||
|
|
||||||
model *Model
|
|
||||||
value interface{}
|
value interface{}
|
||||||
sql string
|
sql string
|
||||||
sqlVars []interface{}
|
sqlVars []interface{}
|
||||||
@ -36,18 +34,20 @@ type Do struct {
|
|||||||
ignoreProtectedAttrs bool
|
ignoreProtectedAttrs bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) tableName() string {
|
func (s *Do) table() string {
|
||||||
if len(s.specifiedTableName) == 0 {
|
if len(s.tableName) == 0 {
|
||||||
s.guessedTableName = s.model.tableName()
|
if len(s.db.search.tableName) == 0 {
|
||||||
return s.guessedTableName
|
s.tableName = s.model.tableName()
|
||||||
} else {
|
} else {
|
||||||
return s.specifiedTableName
|
s.tableName = s.db.search.tableName
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return s.tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) err(err error) error {
|
func (s *Do) err(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.chain.err(err)
|
s.db.err(err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -60,18 +60,18 @@ func (s *Do) setModel(value interface{}) *Do {
|
|||||||
|
|
||||||
func (s *Do) addToVars(value interface{}) string {
|
func (s *Do) addToVars(value interface{}) string {
|
||||||
s.sqlVars = append(s.sqlVars, value)
|
s.sqlVars = append(s.sqlVars, value)
|
||||||
return fmt.Sprintf(s.chain.d.dialect.BinVar(), len(s.sqlVars))
|
return fmt.Sprintf(s.db.dialect.BinVar(), len(s.sqlVars))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) exec(sqls ...string) (err error) {
|
func (s *Do) exec(sqls ...string) (err error) {
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
if len(sqls) > 0 {
|
if len(sqls) > 0 {
|
||||||
s.sql = sqls[0]
|
s.sql = sqls[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
_, err = s.db.Exec(s.sql, s.sqlVars...)
|
_, err = s.db.db.Exec(s.sql, s.sqlVars...)
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.db.slog(s.sql, now, s.sqlVars...)
|
||||||
}
|
}
|
||||||
return s.err(err)
|
return s.err(err)
|
||||||
}
|
}
|
||||||
@ -95,17 +95,17 @@ func (s *Do) prepareCreateSql() {
|
|||||||
|
|
||||||
s.sql = fmt.Sprintf(
|
s.sql = fmt.Sprintf(
|
||||||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||||
s.tableName(),
|
s.table(),
|
||||||
strings.Join(columns, ","),
|
strings.Join(columns, ","),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(sqls, ","),
|
||||||
s.chain.d.dialect.ReturningStr(s.model.primaryKeyDb()),
|
s.db.dialect.ReturningStr(s.model.primaryKeyDb()),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) saveBeforeAssociations() {
|
func (s *Do) saveBeforeAssociations() {
|
||||||
for _, field := range s.model.beforeAssociations() {
|
for _, field := range s.model.beforeAssociations() {
|
||||||
do := &Do{chain: s.chain, db: s.db}
|
do := &Do{db: s.db}
|
||||||
|
|
||||||
reflect_value := reflect.ValueOf(field.Value)
|
reflect_value := reflect.ValueOf(field.Value)
|
||||||
if reflect_value.CanAddr() {
|
if reflect_value.CanAddr() {
|
||||||
@ -134,7 +134,7 @@ func (s *Do) saveAfterAssociations() {
|
|||||||
switch reflect_value.Kind() {
|
switch reflect_value.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < reflect_value.Len(); i++ {
|
for i := 0; i < reflect_value.Len(); i++ {
|
||||||
do := &Do{chain: s.chain, db: s.db}
|
do := &Do{db: s.db}
|
||||||
|
|
||||||
value := reflect_value.Index(i).Addr().Interface()
|
value := reflect_value.Index(i).Addr().Interface()
|
||||||
if len(field.foreignKey) > 0 {
|
if len(field.foreignKey) > 0 {
|
||||||
@ -143,7 +143,7 @@ func (s *Do) saveAfterAssociations() {
|
|||||||
do.setModel(value).save()
|
do.setModel(value).save()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
do := &Do{chain: s.chain, db: s.db}
|
do := &Do{db: s.db}
|
||||||
if reflect_value.CanAddr() {
|
if reflect_value.CanAddr() {
|
||||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
||||||
do.setModel(field.Value).save()
|
do.setModel(field.Value).save()
|
||||||
@ -170,21 +170,21 @@ func (s *Do) create() (i interface{}) {
|
|||||||
s.saveBeforeAssociations()
|
s.saveBeforeAssociations()
|
||||||
s.prepareCreateSql()
|
s.prepareCreateSql()
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
var id interface{}
|
var id interface{}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if s.chain.d.dialect.SupportLastInsertId() {
|
if s.db.dialect.SupportLastInsertId() {
|
||||||
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
||||||
id, err = sql_result.LastInsertId()
|
id, err = sql_result.LastInsertId()
|
||||||
s.err(err)
|
s.err(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||||
}
|
}
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.db.slog(s.sql, now, s.sqlVars...)
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
|
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
|
||||||
|
|
||||||
s.saveAfterAssociations()
|
s.saveAfterAssociations()
|
||||||
@ -236,7 +236,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
|||||||
|
|
||||||
s.sql = fmt.Sprintf(
|
s.sql = fmt.Sprintf(
|
||||||
"UPDATE %v SET %v %v",
|
"UPDATE %v SET %v %v",
|
||||||
s.tableName(),
|
s.table(),
|
||||||
strings.Join(sqls, ", "),
|
strings.Join(sqls, ", "),
|
||||||
s.combinedSql(),
|
s.combinedSql(),
|
||||||
)
|
)
|
||||||
@ -258,7 +258,7 @@ func (s *Do) update() *Do {
|
|||||||
s.saveBeforeAssociations()
|
s.saveBeforeAssociations()
|
||||||
s.prepareUpdateSql(update_attrs)
|
s.prepareUpdateSql(update_attrs)
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
s.exec()
|
s.exec()
|
||||||
s.saveAfterAssociations()
|
s.saveAfterAssociations()
|
||||||
|
|
||||||
@ -272,11 +272,11 @@ func (s *Do) update() *Do {
|
|||||||
func (s *Do) delete() *Do {
|
func (s *Do) delete() *Do {
|
||||||
s.model.callMethod("BeforeDelete")
|
s.model.callMethod("BeforeDelete")
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
||||||
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql())
|
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql())
|
||||||
} else {
|
} else {
|
||||||
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql())
|
||||||
}
|
}
|
||||||
s.exec()
|
s.exec()
|
||||||
s.model.callMethod("AfterDelete")
|
s.model.callMethod("AfterDelete")
|
||||||
@ -285,7 +285,7 @@ func (s *Do) delete() *Do {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareQuerySql() {
|
func (s *Do) prepareQuerySql() {
|
||||||
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -358,10 +358,10 @@ func (s *Do) query() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.db.slog(s.sql, now, s.sqlVars...)
|
||||||
|
|
||||||
if s.err(err) != nil {
|
if s.err(err) != nil {
|
||||||
return
|
return
|
||||||
@ -402,10 +402,10 @@ func (s *Do) query() {
|
|||||||
|
|
||||||
func (s *Do) count(value interface{}) {
|
func (s *Do) count(value interface{}) {
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
|
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.db.slog(s.sql, now, s.sqlVars...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -420,10 +420,10 @@ func (s *Do) pluck(column string, value interface{}) {
|
|||||||
|
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.db.hasError() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.db.slog(s.sql, now, s.sqlVars...)
|
||||||
|
|
||||||
if s.err(err) == nil {
|
if s.err(err) == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
@ -645,25 +645,25 @@ func (s *Do) createTable() *Do {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.tableName(), strings.Join(sqls, ","))
|
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))
|
||||||
|
|
||||||
s.exec()
|
s.exec()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) dropTable() *Do {
|
func (s *Do) dropTable() *Do {
|
||||||
s.sql = fmt.Sprintf("DROP TABLE %v", s.tableName())
|
s.sql = fmt.Sprintf("DROP TABLE %v", s.table())
|
||||||
s.exec()
|
s.exec()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) updateColumn(column string, typ string) {
|
func (s *Do) updateColumn(column string, typ string) {
|
||||||
s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.tableName(), column, typ)
|
s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ)
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) dropColumn(column string) {
|
func (s *Do) dropColumn(column string) {
|
||||||
s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.tableName(), column)
|
s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column)
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -672,22 +672,22 @@ func (s *Do) addIndex(column string, names ...string) {
|
|||||||
if len(names) > 0 {
|
if len(names) > 0 {
|
||||||
index_name = names[0]
|
index_name = names[0]
|
||||||
} else {
|
} else {
|
||||||
index_name = fmt.Sprintf("index_%v_on_%v", s.tableName(), column)
|
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.tableName(), column)
|
s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column)
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) removeIndex(index_name string) {
|
func (s *Do) removeIndex(index_name string) {
|
||||||
s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.tableName())
|
s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table())
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) autoMigrate() *Do {
|
func (s *Do) autoMigrate() *Do {
|
||||||
var table_name string
|
var table_name string
|
||||||
sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.tableName()))
|
sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table()))
|
||||||
s.db.QueryRow(sql, s.sqlVars...).Scan(&table_name)
|
s.db.db.QueryRow(sql, s.sqlVars...).Scan(&table_name)
|
||||||
s.sqlVars = []interface{}{}
|
s.sqlVars = []interface{}{}
|
||||||
|
|
||||||
// If table doesn't exist
|
// If table doesn't exist
|
||||||
@ -696,13 +696,13 @@ func (s *Do) autoMigrate() *Do {
|
|||||||
} else {
|
} else {
|
||||||
for _, field := range s.model.fields("migration") {
|
for _, field := range s.model.fields("migration") {
|
||||||
var column_name, data_type string
|
var column_name, data_type string
|
||||||
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName()))
|
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.table()))
|
||||||
s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
|
s.db.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
|
||||||
s.sqlVars = []interface{}{}
|
s.sqlVars = []interface{}{}
|
||||||
|
|
||||||
// If column doesn't exist
|
// If column doesn't exist
|
||||||
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
|
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
|
||||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.dbName, field.sqlTag())
|
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag())
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -711,9 +711,9 @@ func (s *Do) autoMigrate() *Do {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) begin() *Do {
|
func (s *Do) begin() *Do {
|
||||||
if db, ok := s.db.(sqlDb); ok {
|
if db, ok := s.db.db.(sqlDb); ok {
|
||||||
if tx, err := db.Begin(); err == nil {
|
if tx, err := db.Begin(); err == nil {
|
||||||
s.db = interface{}(tx).(sqlCommon)
|
s.db.db = interface{}(tx).(sqlCommon)
|
||||||
s.startedTransaction = true
|
s.startedTransaction = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -722,8 +722,8 @@ func (s *Do) begin() *Do {
|
|||||||
|
|
||||||
func (s *Do) commit_or_rollback() {
|
func (s *Do) commit_or_rollback() {
|
||||||
if s.startedTransaction {
|
if s.startedTransaction {
|
||||||
if db, ok := s.db.(sqlTx); ok {
|
if db, ok := s.db.db.(sqlTx); ok {
|
||||||
if s.chain.hasError() {
|
if s.db.hasError() {
|
||||||
db.Rollback()
|
db.Rollback()
|
||||||
} else {
|
} else {
|
||||||
db.Commit()
|
db.Commit()
|
||||||
|
13
main.go
13
main.go
@ -109,7 +109,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
|||||||
s.clone().do(out).where(where).initialize()
|
s.clone().do(out).where(where).initialize()
|
||||||
} else {
|
} else {
|
||||||
if len(s.search.assignAttrs) > 0 {
|
if len(s.search.assignAttrs) > 0 {
|
||||||
s.do(out).updateAttrs(s.assignAttrs) //updated or not
|
s.do(out).updateAttrs(s.search.assignAttrs) //updated or not
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
@ -127,13 +127,22 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||||
|
return s.Updates(toSearchableMap(attrs...), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
|
||||||
|
s.do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DB) Save(value interface{}) *DB {
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
s.do(value).begin().save().commit_or_rollback()
|
s.do(value).begin().save().commit_or_rollback()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Delete(value interface{}) *DB {
|
func (s *DB) Delete(value interface{}) *DB {
|
||||||
s.do(value).bengin().delete(value).commit_or_rollback()
|
s.do(value).begin().delete().commit_or_rollback()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,13 +29,13 @@ func (s *DB) hasError() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) print(level string, v ...interface{}) {
|
func (s *DB) print(level string, v ...interface{}) {
|
||||||
if s.d.logMode || s.debug_mode || level == "debug" {
|
if s.logMode || level == "debug" {
|
||||||
if _, ok := s.d.logger.(Logger); !ok {
|
if _, ok := s.parent.logger.(Logger); !ok {
|
||||||
fmt.Println("logger haven't been set, using os.Stdout")
|
fmt.Println("logger haven't been set, using os.Stdout")
|
||||||
s.d.logger = default_logger
|
s.parent.logger = default_logger
|
||||||
}
|
}
|
||||||
args := []interface{}{level}
|
args := []interface{}{level}
|
||||||
s.d.logger.(Logger).Print(append(args, v...)...)
|
s.parent.logger.(Logger).Print(append(args, v...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user