Add getPrimaryKey: analize the tag string in the struct fields and find the one marked as primaryKey

Add primaryKey field to scope and uses getPrimaryKey to find the one marked in that way, if present. Otherwise fallback to id

Format code with gofmt

Fixes getPrimaryKey for non struct type

Add tests

add Tests for update a struct
This commit is contained in:
Paolo Galeone 2014-04-02 11:00:07 +02:00
parent 5574d47f1f
commit fdd9a52800
3 changed files with 149 additions and 10 deletions

View File

@ -7,9 +7,9 @@ import (
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/nerdzeu/gorm"
"os"
"reflect"
@ -127,6 +127,13 @@ type Product struct {
AfterDeleteCallTimes int64
}
type Animal struct {
Counter int64 `primaryKey:"yes"`
Name string
CreatedAt time.Time
UpdatedAt time.Time
}
var (
db gorm.DB
t1, t2, t3, t4, t5 time.Time
@ -170,6 +177,11 @@ func init() {
db.Exec("drop table credit_cards")
db.Exec("drop table roles")
db.Exec("drop table companies")
db.Exec("drop table animals")
if err = db.CreateTable(&Animal{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err = db.CreateTable(&User{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
@ -210,6 +222,11 @@ func init() {
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
db.Save(&User{Name: "5", Age: 26, Birthday: t4})
db.Save(&Animal{Name: "First"})
db.Save(&Animal{Name: "Amazing"})
db.Save(&Animal{Name: "Horse"})
db.Save(&Animal{Name: "Last"})
}
func TestFirstAndLast(t *testing.T) {
@ -230,6 +247,24 @@ func TestFirstAndLast(t *testing.T) {
}
}
func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) {
var animal1, animal2, animal3, animal4 Animal
db.First(&animal1)
db.Order("counter").Find(&animal2)
db.Last(&animal3)
db.Order("counter desc").Find(&animal4)
if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
t.Errorf("First and Last should works correctly")
}
var animals []Animal
db.First(&animals)
if len(animals) != 1 {
t.Errorf("Find first record as map")
}
}
func TestSaveCustomType(t *testing.T) {
var user, user1 User
db.First(&user, "name = ?", "1")
@ -935,17 +970,31 @@ func TestSetTableDirectly(t *testing.T) {
func TestUpdate(t *testing.T) {
product1 := Product{Code: "123"}
product2 := Product{Code: "234"}
animal1 := Animal{Name: "Ferdinand"}
animal2 := Animal{Name: "nerdz"}
db.Save(&product1).Save(&product2).Update("code", "456")
if product2.Code != "456" {
t.Errorf("Record should be updated with update attributes")
}
db.Save(&animal1).Save(&animal2).Update("name", "Francis")
if animal2.Name != "Francis" {
t.Errorf("Record should be updated with update attributes")
}
db.First(&product1, product1.Id)
db.First(&product2, product2.Id)
updated_at1 := product1.UpdatedAt
updated_at2 := product2.UpdatedAt
db.First(&animal1, animal1.Counter)
db.First(&animal2, animal2.Counter)
animalUpdated_at1 := animal1.UpdatedAt
animalUpdated_at2 := animal2.UpdatedAt
var product3 Product
db.First(&product3, product2.Id).Update("code", "456")
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
@ -964,6 +1013,25 @@ func TestUpdate(t *testing.T) {
t.Errorf("Product 234 should be changed to 456")
}
var animal3 Animal
db.First(&animal3, animal2.Counter).Update("Name", "Robert")
if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updated_at should not be updated if nothing changed")
}
if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil {
t.Errorf("Animal 'Ferdinand' should not be updated")
}
if db.First(&Animal{}, "name = 'nerdz'").Error == nil {
t.Errorf("Animal 'nerdz' should be changed to 'Francis'")
}
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
t.Errorf("Animal 'nerdz' should be changed to 'Robert'")
}
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
var product4 Product
@ -991,6 +1059,34 @@ func TestUpdate(t *testing.T) {
if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update_column with CamelCase")
}
db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz")
var animal4 Animal
db.First(&animal4, animal1.Counter)
if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("animalUpdated_at should be updated if something changed")
}
if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil {
t.Errorf("Animal 'Fredinand' should be changed to 'Franz'")
}
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
t.Errorf("Animal 'Robert' should not be changed to 'Francis'")
}
if db.First(&Animal{}, "name = 'Franz'").Error != nil {
t.Errorf("Product 'nerdz' should be changed to 'Franz'")
}
if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update with CamelCase")
}
if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update_column with CamelCase")
}
}
func TestUpdates(t *testing.T) {

View File

@ -13,13 +13,14 @@ import (
)
type Scope struct {
Value interface{}
Search *search
Sql string
SqlVars []interface{}
db *DB
_values map[string]interface{}
skipLeft bool
Value interface{}
Search *search
Sql string
SqlVars []interface{}
db *DB
_values map[string]interface{}
skipLeft bool
primaryKey string
}
// NewScope create scope for callbacks, including DB's search information
@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool {
// PrimaryKey get the primary key's column name
func (scope *Scope) PrimaryKey() string {
return "id"
if scope.primaryKey != "" {
return scope.primaryKey
}
scope.primaryKey = scope.getPrimaryKey()
return scope.primaryKey
}
// PrimaryKeyZero check the primary key is blank or not
@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field {
value := indirectValue.FieldByName(fieldStruct.Name)
field.Value = value.Interface()
field.IsBlank = isBlank(value)
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
// Search for primary key tag identifier
field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != ""
if field.isPrimaryKey {
scope.primaryKey = field.DBName
}
if scope.db != nil {
field.Tag = fieldStruct.Tag

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"go/ast"
"reflect"
"regexp"
"strconv"
@ -472,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope {
}
return scope
}
func (scope *Scope) getPrimaryKey() string {
var indirectValue reflect.Value
indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value))
if indirectValue.Kind() == reflect.Slice {
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
}
if !indirectValue.IsValid() {
return "id"
}
scopeTyp := indirectValue.Type()
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
// if primaryKey tag found, return column name
if fieldStruct.Tag.Get("primaryKey") != "" {
return toSnake(fieldStruct.Name)
}
}
//If primaryKey tag not found, fallback to id
return "id"
}