gorm/scope.go
Paolo Galeone fdd9a52800 Add getPrimaryKey: analize the tag string in the struct fields and find the one marked as primaryKey
Add primaryKey field to scope and uses getPrimaryKey to find the one marked in that way, if present. Otherwise fallback to id

Format code with gofmt

Fixes getPrimaryKey for non struct type

Add tests

add Tests for update a struct
2014-04-06 03:59:48 +02:00

356 lines
8.6 KiB
Go

package gorm
import (
"errors"
"fmt"
"github.com/jinzhu/gorm/dialect"
"go/ast"
"strings"
"time"
"reflect"
"regexp"
)
type Scope struct {
Value interface{}
Search *search
Sql string
SqlVars []interface{}
db *DB
_values map[string]interface{}
skipLeft bool
primaryKey string
}
// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
db.Value = value
return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}}
}
// New create a new Scope without search information
func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.db.parent, Search: &search{}, Value: value}
}
// NewDB create a new DB without search information
func (scope *Scope) NewDB() *DB {
return scope.db.new()
}
// DB get *sql.DB
func (scope *Scope) DB() sqlCommon {
return scope.db.db
}
// SkipLeft skip remaining callbacks
func (scope *Scope) SkipLeft() {
scope.skipLeft = true
}
// Quote used to quote database column name according to database dialect
func (scope *Scope) Quote(str string) string {
return scope.Dialect().Quote(str)
}
// Dialect get dialect
func (scope *Scope) Dialect() dialect.Dialect {
return scope.db.parent.dialect
}
// Err write error
func (scope *Scope) Err(err error) error {
if err != nil {
scope.db.err(err)
}
return err
}
// Log print log message
func (scope *Scope) Log(v ...interface{}) {
scope.db.log(v...)
}
// HasError check if there are any error
func (scope *Scope) HasError() bool {
return scope.db.Error != nil
}
// PrimaryKey get the primary key's column name
func (scope *Scope) PrimaryKey() string {
if scope.primaryKey != "" {
return scope.primaryKey
}
scope.primaryKey = scope.getPrimaryKey()
return scope.primaryKey
}
// PrimaryKeyZero check the primary key is blank or not
func (scope *Scope) PrimaryKeyZero() bool {
return isBlank(reflect.ValueOf(scope.PrimaryKeyValue()))
}
// PrimaryKeyValue get the primary key's value
func (scope *Scope) PrimaryKeyValue() interface{} {
data := reflect.Indirect(reflect.ValueOf(scope.Value))
if data.Kind() == reflect.Struct {
if field := data.FieldByName(snakeToUpperCamel(scope.PrimaryKey())); field.IsValid() {
return field.Interface()
}
}
return 0
}
// HasColumn to check if has column
func (scope *Scope) HasColumn(name string) bool {
_, result := scope.FieldByName(name)
return result
}
// FieldByName to get column's value and existence
func (scope *Scope) FieldByName(name string) (interface{}, bool) {
data := reflect.Indirect(reflect.ValueOf(scope.Value))
if data.Kind() == reflect.Struct {
if field := data.FieldByName(name); field.IsValid() {
return field.Interface(), true
}
} else if data.Kind() == reflect.Slice {
return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
}
return nil, false
}
// SetColumn to set the column's value
func (scope *Scope) SetColumn(column string, value interface{}) {
if scope.Value == nil {
return
}
data := reflect.Indirect(reflect.ValueOf(scope.Value))
setFieldValue(data.FieldByName(snakeToUpperCamel(column)), value)
}
// CallMethod invoke method with necessary argument
func (scope *Scope) CallMethod(name string) {
if scope.Value == nil {
return
}
call := func(value interface{}) {
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
fi := fm.Interface()
if f, ok := fi.(func()); ok {
f()
} else if f, ok := fi.(func(s *Scope)); ok {
f(scope)
} else if f, ok := fi.(func(s *DB)); ok {
f(scope.db.new())
} else if f, ok := fi.(func() error); ok {
scope.Err(f())
} else if f, ok := fi.(func(s *Scope) error); ok {
scope.Err(f(scope))
} else if f, ok := fi.(func(s *DB) error); ok {
scope.Err(f(scope.db.new()))
} else {
scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name)))
}
}
}
if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice {
for i := 0; i < values.Len(); i++ {
call(values.Index(i).Addr().Interface())
}
} else {
call(scope.Value)
}
}
// AddToVars add value as sql's vars, gorm will escape them
func (scope *Scope) AddToVars(value interface{}) string {
scope.SqlVars = append(scope.SqlVars, value)
return scope.Dialect().BinVar(len(scope.SqlVars))
}
// TableName get table name
func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName
} else {
if scope.Value == nil {
scope.Err(errors.New("can't get table name"))
return ""
}
data := reflect.Indirect(reflect.ValueOf(scope.Value))
if data.Kind() == reflect.Slice {
data = reflect.New(data.Type().Elem()).Elem()
}
if fm := data.MethodByName("TableName"); fm.IsValid() {
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if result, ok := v[0].Interface().(string); ok {
return result
}
}
}
str := toSnake(data.Type().Name())
if !scope.db.parent.singularTable {
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
for key, value := range pluralMap {
reg := regexp.MustCompile(key + "$")
if reg.MatchString(str) {
return reg.ReplaceAllString(str, value)
}
}
}
return str
}
}
// CombinedConditionSql get combined condition sql
func (scope *Scope) CombinedConditionSql() string {
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
}
// Fields get value's fields
func (scope *Scope) Fields() []*Field {
indirectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
fields := []*Field{}
if !indirectValue.IsValid() {
return fields
}
scopeTyp := indirectValue.Type()
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
var field Field
field.Name = fieldStruct.Name
field.DBName = toSnake(fieldStruct.Name)
value := indirectValue.FieldByName(fieldStruct.Name)
field.Value = value.Interface()
field.IsBlank = isBlank(value)
// Search for primary key tag identifier
field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != ""
if field.isPrimaryKey {
scope.primaryKey = field.DBName
}
if scope.db != nil {
field.Tag = fieldStruct.Tag
field.SqlTag = scope.sqlTagForField(&field)
// parse association
elem := reflect.Indirect(value)
typ := elem.Type()
switch elem.Kind() {
case reflect.Slice:
typ = typ.Elem()
if typ.Kind() == reflect.Struct {
foreignKey := scopeTyp.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey
}
field.AfterAssociation = true
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
if scope.HasColumn(field.Name + "Id") {
field.ForeignKey = field.Name + "Id"
field.BeforeAssociation = true
} else {
foreignKey := scopeTyp.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey
}
field.AfterAssociation = true
}
}
}
}
fields = append(fields, &field)
}
return fields
}
// Raw set sql
func (scope *Scope) Raw(sql string) *Scope {
scope.Sql = strings.Replace(sql, "$$", "?", -1)
return scope
}
// Exec invoke sql
func (scope *Scope) Exec() *Scope {
defer scope.Trace(time.Now())
if !scope.HasError() {
_, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
scope.Err(err)
}
return scope
}
// Set set value by name
func (scope *Scope) Set(name string, value interface{}) *Scope {
scope._values[name] = value
return scope
}
// Get get value by name
func (scope *Scope) Get(name string) (value interface{}, ok bool) {
value, ok = scope._values[name]
return
}
// Trace print sql log
func (scope *Scope) Trace(t time.Time) {
if len(scope.Sql) > 0 {
scope.db.slog(scope.Sql, t, scope.SqlVars...)
}
}
// Begin start a transaction
func (scope *Scope) Begin() *Scope {
if db, ok := scope.DB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.Set("gorm:started_transaction", true)
}
}
return scope
}
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
func (scope *Scope) CommitOrRollback() *Scope {
if _, ok := scope.Get("gorm:started_transaction"); ok {
if db, ok := scope.db.db.(sqlTx); ok {
if scope.HasError() {
db.Rollback()
} else {
db.Commit()
}
scope.db.db = scope.db.parent.db
}
}
return scope
}