Add getPrimaryKey: analize the tag string in the struct fields and find the one marked as primaryKey
Add primaryKey field to scope and uses getPrimaryKey to find the one marked in that way, if present. Otherwise fallback to id Format code with gofmt Fixes getPrimaryKey for non struct type Add tests add Tests for update a struct
This commit is contained in:
parent
5574d47f1f
commit
fdd9a52800
98
main_test.go
98
main_test.go
@ -7,9 +7,9 @@ import (
|
||||
"fmt"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/nerdzeu/gorm"
|
||||
|
||||
"os"
|
||||
"reflect"
|
||||
@ -127,6 +127,13 @@ type Product struct {
|
||||
AfterDeleteCallTimes int64
|
||||
}
|
||||
|
||||
type Animal struct {
|
||||
Counter int64 `primaryKey:"yes"`
|
||||
Name string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
db gorm.DB
|
||||
t1, t2, t3, t4, t5 time.Time
|
||||
@ -170,6 +177,11 @@ func init() {
|
||||
db.Exec("drop table credit_cards")
|
||||
db.Exec("drop table roles")
|
||||
db.Exec("drop table companies")
|
||||
db.Exec("drop table animals")
|
||||
|
||||
if err = db.CreateTable(&Animal{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
if err = db.CreateTable(&User{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
@ -210,6 +222,11 @@ func init() {
|
||||
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
||||
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
||||
db.Save(&User{Name: "5", Age: 26, Birthday: t4})
|
||||
|
||||
db.Save(&Animal{Name: "First"})
|
||||
db.Save(&Animal{Name: "Amazing"})
|
||||
db.Save(&Animal{Name: "Horse"})
|
||||
db.Save(&Animal{Name: "Last"})
|
||||
}
|
||||
|
||||
func TestFirstAndLast(t *testing.T) {
|
||||
@ -230,6 +247,24 @@ func TestFirstAndLast(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) {
|
||||
var animal1, animal2, animal3, animal4 Animal
|
||||
db.First(&animal1)
|
||||
db.Order("counter").Find(&animal2)
|
||||
|
||||
db.Last(&animal3)
|
||||
db.Order("counter desc").Find(&animal4)
|
||||
if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
|
||||
t.Errorf("First and Last should works correctly")
|
||||
}
|
||||
|
||||
var animals []Animal
|
||||
db.First(&animals)
|
||||
if len(animals) != 1 {
|
||||
t.Errorf("Find first record as map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveCustomType(t *testing.T) {
|
||||
var user, user1 User
|
||||
db.First(&user, "name = ?", "1")
|
||||
@ -935,17 +970,31 @@ func TestSetTableDirectly(t *testing.T) {
|
||||
func TestUpdate(t *testing.T) {
|
||||
product1 := Product{Code: "123"}
|
||||
product2 := Product{Code: "234"}
|
||||
animal1 := Animal{Name: "Ferdinand"}
|
||||
animal2 := Animal{Name: "nerdz"}
|
||||
|
||||
db.Save(&product1).Save(&product2).Update("code", "456")
|
||||
|
||||
if product2.Code != "456" {
|
||||
t.Errorf("Record should be updated with update attributes")
|
||||
}
|
||||
|
||||
db.Save(&animal1).Save(&animal2).Update("name", "Francis")
|
||||
|
||||
if animal2.Name != "Francis" {
|
||||
t.Errorf("Record should be updated with update attributes")
|
||||
}
|
||||
|
||||
db.First(&product1, product1.Id)
|
||||
db.First(&product2, product2.Id)
|
||||
updated_at1 := product1.UpdatedAt
|
||||
updated_at2 := product2.UpdatedAt
|
||||
|
||||
db.First(&animal1, animal1.Counter)
|
||||
db.First(&animal2, animal2.Counter)
|
||||
animalUpdated_at1 := animal1.UpdatedAt
|
||||
animalUpdated_at2 := animal2.UpdatedAt
|
||||
|
||||
var product3 Product
|
||||
db.First(&product3, product2.Id).Update("code", "456")
|
||||
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
@ -964,6 +1013,25 @@ func TestUpdate(t *testing.T) {
|
||||
t.Errorf("Product 234 should be changed to 456")
|
||||
}
|
||||
|
||||
var animal3 Animal
|
||||
db.First(&animal3, animal2.Counter).Update("Name", "Robert")
|
||||
|
||||
if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("updated_at should not be updated if nothing changed")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil {
|
||||
t.Errorf("Animal 'Ferdinand' should not be updated")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'nerdz'").Error == nil {
|
||||
t.Errorf("Animal 'nerdz' should be changed to 'Francis'")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
|
||||
t.Errorf("Animal 'nerdz' should be changed to 'Robert'")
|
||||
}
|
||||
|
||||
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
|
||||
|
||||
var product4 Product
|
||||
@ -991,6 +1059,34 @@ func TestUpdate(t *testing.T) {
|
||||
if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||
t.Error("No error should raise when update_column with CamelCase")
|
||||
}
|
||||
|
||||
db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz")
|
||||
|
||||
var animal4 Animal
|
||||
db.First(&animal4, animal1.Counter)
|
||||
if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("animalUpdated_at should be updated if something changed")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil {
|
||||
t.Errorf("Animal 'Fredinand' should be changed to 'Franz'")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
|
||||
t.Errorf("Animal 'Robert' should not be changed to 'Francis'")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Franz'").Error != nil {
|
||||
t.Errorf("Product 'nerdz' should be changed to 'Franz'")
|
||||
}
|
||||
|
||||
if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||
t.Error("No error should raise when update with CamelCase")
|
||||
}
|
||||
|
||||
if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||
t.Error("No error should raise when update_column with CamelCase")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdates(t *testing.T) {
|
||||
|
30
scope.go
30
scope.go
@ -13,13 +13,14 @@ import (
|
||||
)
|
||||
|
||||
type Scope struct {
|
||||
Value interface{}
|
||||
Search *search
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
db *DB
|
||||
_values map[string]interface{}
|
||||
skipLeft bool
|
||||
Value interface{}
|
||||
Search *search
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
db *DB
|
||||
_values map[string]interface{}
|
||||
skipLeft bool
|
||||
primaryKey string
|
||||
}
|
||||
|
||||
// NewScope create scope for callbacks, including DB's search information
|
||||
@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool {
|
||||
|
||||
// PrimaryKey get the primary key's column name
|
||||
func (scope *Scope) PrimaryKey() string {
|
||||
return "id"
|
||||
if scope.primaryKey != "" {
|
||||
return scope.primaryKey
|
||||
}
|
||||
|
||||
scope.primaryKey = scope.getPrimaryKey()
|
||||
return scope.primaryKey
|
||||
}
|
||||
|
||||
// PrimaryKeyZero check the primary key is blank or not
|
||||
@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field {
|
||||
value := indirectValue.FieldByName(fieldStruct.Name)
|
||||
field.Value = value.Interface()
|
||||
field.IsBlank = isBlank(value)
|
||||
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
|
||||
|
||||
// Search for primary key tag identifier
|
||||
field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != ""
|
||||
|
||||
if field.isPrimaryKey {
|
||||
scope.primaryKey = field.DBName
|
||||
}
|
||||
|
||||
if scope.db != nil {
|
||||
field.Tag = fieldStruct.Tag
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@ -472,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope {
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) getPrimaryKey() string {
|
||||
var indirectValue reflect.Value
|
||||
|
||||
indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
|
||||
if indirectValue.Kind() == reflect.Slice {
|
||||
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
|
||||
}
|
||||
|
||||
if !indirectValue.IsValid() {
|
||||
return "id"
|
||||
}
|
||||
|
||||
scopeTyp := indirectValue.Type()
|
||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||
fieldStruct := scopeTyp.Field(i)
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// if primaryKey tag found, return column name
|
||||
if fieldStruct.Tag.Get("primaryKey") != "" {
|
||||
return toSnake(fieldStruct.Name)
|
||||
}
|
||||
}
|
||||
|
||||
//If primaryKey tag not found, fallback to id
|
||||
return "id"
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user