From fdd9a528005b9e243c8af1ba0ee202f9b413f0e0 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Wed, 2 Apr 2014 11:00:07 +0200 Subject: [PATCH] Add getPrimaryKey: analize the tag string in the struct fields and find the one marked as primaryKey Add primaryKey field to scope and uses getPrimaryKey to find the one marked in that way, if present. Otherwise fallback to id Format code with gofmt Fixes getPrimaryKey for non struct type Add tests add Tests for update a struct --- main_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++- scope.go | 30 ++++++++++----- scope_private.go | 31 +++++++++++++++ 3 files changed, 149 insertions(+), 10 deletions(-) diff --git a/main_test.go b/main_test.go index 5e8e3070..5deb7822 100644 --- a/main_test.go +++ b/main_test.go @@ -7,9 +7,9 @@ import ( "fmt" _ "github.com/go-sql-driver/mysql" - "github.com/jinzhu/gorm" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + "github.com/nerdzeu/gorm" "os" "reflect" @@ -127,6 +127,13 @@ type Product struct { AfterDeleteCallTimes int64 } +type Animal struct { + Counter int64 `primaryKey:"yes"` + Name string + CreatedAt time.Time + UpdatedAt time.Time +} + var ( db gorm.DB t1, t2, t3, t4, t5 time.Time @@ -170,6 +177,11 @@ func init() { db.Exec("drop table credit_cards") db.Exec("drop table roles") db.Exec("drop table companies") + db.Exec("drop table animals") + + if err = db.CreateTable(&Animal{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } if err = db.CreateTable(&User{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) @@ -210,6 +222,11 @@ func init() { db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 24, Birthday: t4}) db.Save(&User{Name: "5", Age: 26, Birthday: t4}) + + db.Save(&Animal{Name: "First"}) + db.Save(&Animal{Name: "Amazing"}) + db.Save(&Animal{Name: "Horse"}) + db.Save(&Animal{Name: "Last"}) } func TestFirstAndLast(t *testing.T) { @@ -230,6 +247,24 @@ func TestFirstAndLast(t *testing.T) { } } +func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) { + var animal1, animal2, animal3, animal4 Animal + db.First(&animal1) + db.Order("counter").Find(&animal2) + + db.Last(&animal3) + db.Order("counter desc").Find(&animal4) + if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { + t.Errorf("First and Last should works correctly") + } + + var animals []Animal + db.First(&animals) + if len(animals) != 1 { + t.Errorf("Find first record as map") + } +} + func TestSaveCustomType(t *testing.T) { var user, user1 User db.First(&user, "name = ?", "1") @@ -935,17 +970,31 @@ func TestSetTableDirectly(t *testing.T) { func TestUpdate(t *testing.T) { product1 := Product{Code: "123"} product2 := Product{Code: "234"} + animal1 := Animal{Name: "Ferdinand"} + animal2 := Animal{Name: "nerdz"} + db.Save(&product1).Save(&product2).Update("code", "456") if product2.Code != "456" { t.Errorf("Record should be updated with update attributes") } + db.Save(&animal1).Save(&animal2).Update("name", "Francis") + + if animal2.Name != "Francis" { + t.Errorf("Record should be updated with update attributes") + } + db.First(&product1, product1.Id) db.First(&product2, product2.Id) updated_at1 := product1.UpdatedAt updated_at2 := product2.UpdatedAt + db.First(&animal1, animal1.Counter) + db.First(&animal2, animal2.Counter) + animalUpdated_at1 := animal1.UpdatedAt + animalUpdated_at2 := animal2.UpdatedAt + var product3 Product db.First(&product3, product2.Id).Update("code", "456") if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { @@ -964,6 +1013,25 @@ func TestUpdate(t *testing.T) { t.Errorf("Product 234 should be changed to 456") } + var animal3 Animal + db.First(&animal3, animal2.Counter).Update("Name", "Robert") + + if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should not be updated if nothing changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil { + t.Errorf("Animal 'Ferdinand' should not be updated") + } + + if db.First(&Animal{}, "name = 'nerdz'").Error == nil { + t.Errorf("Animal 'nerdz' should be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'nerdz' should be changed to 'Robert'") + } + db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789") var product4 Product @@ -991,6 +1059,34 @@ func TestUpdate(t *testing.T) { if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { t.Error("No error should raise when update_column with CamelCase") } + + db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz") + + var animal4 Animal + db.First(&animal4, animal1.Counter) + if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("animalUpdated_at should be updated if something changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil { + t.Errorf("Animal 'Fredinand' should be changed to 'Franz'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'Robert' should not be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Franz'").Error != nil { + t.Errorf("Product 'nerdz' should be changed to 'Franz'") + } + + if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update with CamelCase") + } + + if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update_column with CamelCase") + } } func TestUpdates(t *testing.T) { diff --git a/scope.go b/scope.go index 40697f9a..543c4cde 100644 --- a/scope.go +++ b/scope.go @@ -13,13 +13,14 @@ import ( ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB - _values map[string]interface{} - skipLeft bool + Value interface{} + Search *search + Sql string + SqlVars []interface{} + db *DB + _values map[string]interface{} + skipLeft bool + primaryKey string } // NewScope create scope for callbacks, including DB's search information @@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool { // PrimaryKey get the primary key's column name func (scope *Scope) PrimaryKey() string { - return "id" + if scope.primaryKey != "" { + return scope.primaryKey + } + + scope.primaryKey = scope.getPrimaryKey() + return scope.primaryKey } // PrimaryKeyZero check the primary key is blank or not @@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field { value := indirectValue.FieldByName(fieldStruct.Name) field.Value = value.Interface() field.IsBlank = isBlank(value) - field.isPrimaryKey = scope.PrimaryKey() == field.DBName + + // Search for primary key tag identifier + field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != "" + + if field.isPrimaryKey { + scope.primaryKey = field.DBName + } if scope.db != nil { field.Tag = fieldStruct.Tag diff --git a/scope_private.go b/scope_private.go index 72f631cc..d9f16438 100644 --- a/scope_private.go +++ b/scope_private.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "errors" "fmt" + "go/ast" "reflect" "regexp" "strconv" @@ -472,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope { } return scope } + +func (scope *Scope) getPrimaryKey() string { + var indirectValue reflect.Value + + indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value)) + + if indirectValue.Kind() == reflect.Slice { + indirectValue = reflect.New(indirectValue.Type().Elem()).Elem() + } + + if !indirectValue.IsValid() { + return "id" + } + + scopeTyp := indirectValue.Type() + for i := 0; i < scopeTyp.NumField(); i++ { + fieldStruct := scopeTyp.Field(i) + if !ast.IsExported(fieldStruct.Name) { + continue + } + + // if primaryKey tag found, return column name + if fieldStruct.Tag.Get("primaryKey") != "" { + return toSnake(fieldStruct.Name) + } + } + + //If primaryKey tag not found, fallback to id + return "id" +}