diff --git a/README.md b/README.md index 8a904a3e..5059f71d 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ go get github.com/jinzhu/gorm ## Conventions * Table name is the plural of struct name's snake case. - Disable pluralization with `db.SingularTable(true)`, or [specify your table name](#specify-table-name) + Disable pluralization with `db.SingularTable(true)`, or [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName) * Column name is the snake case of field's name. * Use `Id int64` field as primary key. * Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)`. @@ -47,6 +47,19 @@ db.First(&user) DB.Save(&User{Name: "xxx"}) // table "users" ``` +## Existing schema + +If you have and existing database schema and some of your tables does not follow the conventions, (and you can't rename your table names), please use: [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName). + +If your primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key. + +```go +type Animal struct { // animals + AnimalId int64 `primaryKey:"yes"` + Birthday time.Time + Age int64 +``` + # Getting Started ```go diff --git a/main_test.go b/main_test.go index 5e8e3070..5deb7822 100644 --- a/main_test.go +++ b/main_test.go @@ -7,9 +7,9 @@ import ( "fmt" _ "github.com/go-sql-driver/mysql" - "github.com/jinzhu/gorm" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + "github.com/nerdzeu/gorm" "os" "reflect" @@ -127,6 +127,13 @@ type Product struct { AfterDeleteCallTimes int64 } +type Animal struct { + Counter int64 `primaryKey:"yes"` + Name string + CreatedAt time.Time + UpdatedAt time.Time +} + var ( db gorm.DB t1, t2, t3, t4, t5 time.Time @@ -170,6 +177,11 @@ func init() { db.Exec("drop table credit_cards") db.Exec("drop table roles") db.Exec("drop table companies") + db.Exec("drop table animals") + + if err = db.CreateTable(&Animal{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } if err = db.CreateTable(&User{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) @@ -210,6 +222,11 @@ func init() { db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 24, Birthday: t4}) db.Save(&User{Name: "5", Age: 26, Birthday: t4}) + + db.Save(&Animal{Name: "First"}) + db.Save(&Animal{Name: "Amazing"}) + db.Save(&Animal{Name: "Horse"}) + db.Save(&Animal{Name: "Last"}) } func TestFirstAndLast(t *testing.T) { @@ -230,6 +247,24 @@ func TestFirstAndLast(t *testing.T) { } } +func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) { + var animal1, animal2, animal3, animal4 Animal + db.First(&animal1) + db.Order("counter").Find(&animal2) + + db.Last(&animal3) + db.Order("counter desc").Find(&animal4) + if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { + t.Errorf("First and Last should works correctly") + } + + var animals []Animal + db.First(&animals) + if len(animals) != 1 { + t.Errorf("Find first record as map") + } +} + func TestSaveCustomType(t *testing.T) { var user, user1 User db.First(&user, "name = ?", "1") @@ -935,17 +970,31 @@ func TestSetTableDirectly(t *testing.T) { func TestUpdate(t *testing.T) { product1 := Product{Code: "123"} product2 := Product{Code: "234"} + animal1 := Animal{Name: "Ferdinand"} + animal2 := Animal{Name: "nerdz"} + db.Save(&product1).Save(&product2).Update("code", "456") if product2.Code != "456" { t.Errorf("Record should be updated with update attributes") } + db.Save(&animal1).Save(&animal2).Update("name", "Francis") + + if animal2.Name != "Francis" { + t.Errorf("Record should be updated with update attributes") + } + db.First(&product1, product1.Id) db.First(&product2, product2.Id) updated_at1 := product1.UpdatedAt updated_at2 := product2.UpdatedAt + db.First(&animal1, animal1.Counter) + db.First(&animal2, animal2.Counter) + animalUpdated_at1 := animal1.UpdatedAt + animalUpdated_at2 := animal2.UpdatedAt + var product3 Product db.First(&product3, product2.Id).Update("code", "456") if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { @@ -964,6 +1013,25 @@ func TestUpdate(t *testing.T) { t.Errorf("Product 234 should be changed to 456") } + var animal3 Animal + db.First(&animal3, animal2.Counter).Update("Name", "Robert") + + if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should not be updated if nothing changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil { + t.Errorf("Animal 'Ferdinand' should not be updated") + } + + if db.First(&Animal{}, "name = 'nerdz'").Error == nil { + t.Errorf("Animal 'nerdz' should be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'nerdz' should be changed to 'Robert'") + } + db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789") var product4 Product @@ -991,6 +1059,34 @@ func TestUpdate(t *testing.T) { if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { t.Error("No error should raise when update_column with CamelCase") } + + db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz") + + var animal4 Animal + db.First(&animal4, animal1.Counter) + if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("animalUpdated_at should be updated if something changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil { + t.Errorf("Animal 'Fredinand' should be changed to 'Franz'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'Robert' should not be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Franz'").Error != nil { + t.Errorf("Product 'nerdz' should be changed to 'Franz'") + } + + if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update with CamelCase") + } + + if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update_column with CamelCase") + } } func TestUpdates(t *testing.T) { diff --git a/scope.go b/scope.go index 40697f9a..543c4cde 100644 --- a/scope.go +++ b/scope.go @@ -13,13 +13,14 @@ import ( ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB - _values map[string]interface{} - skipLeft bool + Value interface{} + Search *search + Sql string + SqlVars []interface{} + db *DB + _values map[string]interface{} + skipLeft bool + primaryKey string } // NewScope create scope for callbacks, including DB's search information @@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool { // PrimaryKey get the primary key's column name func (scope *Scope) PrimaryKey() string { - return "id" + if scope.primaryKey != "" { + return scope.primaryKey + } + + scope.primaryKey = scope.getPrimaryKey() + return scope.primaryKey } // PrimaryKeyZero check the primary key is blank or not @@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field { value := indirectValue.FieldByName(fieldStruct.Name) field.Value = value.Interface() field.IsBlank = isBlank(value) - field.isPrimaryKey = scope.PrimaryKey() == field.DBName + + // Search for primary key tag identifier + field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != "" + + if field.isPrimaryKey { + scope.primaryKey = field.DBName + } if scope.db != nil { field.Tag = fieldStruct.Tag diff --git a/scope_private.go b/scope_private.go index 72f631cc..d9f16438 100644 --- a/scope_private.go +++ b/scope_private.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "errors" "fmt" + "go/ast" "reflect" "regexp" "strconv" @@ -472,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope { } return scope } + +func (scope *Scope) getPrimaryKey() string { + var indirectValue reflect.Value + + indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value)) + + if indirectValue.Kind() == reflect.Slice { + indirectValue = reflect.New(indirectValue.Type().Elem()).Elem() + } + + if !indirectValue.IsValid() { + return "id" + } + + scopeTyp := indirectValue.Type() + for i := 0; i < scopeTyp.NumField(); i++ { + fieldStruct := scopeTyp.Field(i) + if !ast.IsExported(fieldStruct.Name) { + continue + } + + // if primaryKey tag found, return column name + if fieldStruct.Tag.Get("primaryKey") != "" { + return toSnake(fieldStruct.Name) + } + } + + //If primaryKey tag not found, fallback to id + return "id" +}