diff --git a/callbacks/create.go b/callbacks/create.go index c00a0a73..e8c4d532 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -92,12 +92,12 @@ func Create(config *Config) func(db *gorm.DB) { } } } else { - db.AddError(err) + db.AddErrorFromDB(err) } } } } else { - db.AddError(err) + db.AddErrorFromDB(err) } } } @@ -187,14 +187,14 @@ func CreateWithReturning(db *gorm.DB) { } } } else { - db.AddError(err) + db.AddErrorFromDB(err) } } } else if !db.DryRun && db.Error == nil { if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { db.RowsAffected, _ = result.RowsAffected() } else { - db.AddError(err) + db.AddErrorFromDB(err) } } } diff --git a/callbacks/raw.go b/callbacks/raw.go index d594ab39..b4190a65 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -8,7 +8,7 @@ func RawExec(db *gorm.DB) { if db.Error == nil && !db.DryRun { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { - db.AddError(err) + db.AddErrorFromDB(err) } else { db.RowsAffected, _ = result.RowsAffected() } diff --git a/callbacks/update.go b/callbacks/update.go index 46f59157..c8557a77 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -80,7 +80,7 @@ func Update(db *gorm.DB) { if err == nil { db.RowsAffected, _ = result.RowsAffected() } else { - db.AddError(err) + db.AddErrorFromDB(err) } } } diff --git a/errors.go b/errors.go index 08755083..34cab7e0 100644 --- a/errors.go +++ b/errors.go @@ -2,6 +2,8 @@ package gorm import ( "errors" + "fmt" + "strings" ) var ( @@ -32,3 +34,21 @@ var ( // ErrDryRunModeUnsupported dry run mode unsupported ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") ) + +// ErrUniqueConstraint unique constraint error +type ErrUniqueConstraint struct { + ConstraintName string + Columns []string +} + +func (e *ErrUniqueConstraint) Error() string { + if len(e.ConstraintName) > 0 { + return fmt.Sprintf("unique constraint '%s' error", e.ConstraintName) + } + + if len(e.Columns) > 0 { + return fmt.Sprintf("unique constraint on columns '%s' error", strings.Join(e.Columns, ", ")) + } + + return "unique constraint error" +} diff --git a/gorm.go b/gorm.go index 8efd8a73..d4692ab1 100644 --- a/gorm.go +++ b/gorm.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "sync" "time" @@ -252,6 +253,49 @@ func (db *DB) AddError(err error) error { return db.Error } +// AddErrorFromDB add error from database +func (db *DB) AddErrorFromDB(err error) error { + msg := err.Error() + // === Unique constraint errors + if strings.HasPrefix(msg, "UNIQUE constraint failed:") { + // SQLite3 + cols := strings.Split(msg[25:], ",") + for i := range cols { + cols[i] = strings.TrimSpace(cols[i]) + } + err = &ErrUniqueConstraint{ + Columns: cols, + } + } else if strings.HasPrefix(msg, "Error 1062: Duplicate entry") { + // MySQL + constr := strings.Trim(strings.TrimSpace(msg[strings.Index(msg, "for key")+7:]), "'") + p := strings.Split(constr, ".") + if len(p) == 1 { + constr = p[0] + } else { + constr = p[1] + } + + err = &ErrUniqueConstraint{ + ConstraintName: constr, + } + } else if strings.HasPrefix(msg, "ERROR: duplicate key value violates unique constraint") { + // PostgreSQL + constr := "" + i := strings.Index(msg, `"`) + j := strings.LastIndex(msg, `"`) + if i != -1 && j != i { + constr = msg[i+1 : j] + } + + err = &ErrUniqueConstraint{ + ConstraintName: constr, + } + } + + return db.AddError(err) +} + // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool diff --git a/schema/model_test.go b/schema/model_test.go index 1f2b0948..d9aed115 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -23,6 +23,7 @@ type User struct { Team []*User `gorm:"foreignkey:ManagerID"` Languages []*tests.Language `gorm:"many2many:UserSpeak"` Friends []*User `gorm:"many2many:user_friends"` + Contacts []*tests.Contact Active *bool } diff --git a/schema/schema_test.go b/schema/schema_test.go index 6ca5b269..4047303f 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -108,6 +108,10 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { }}, References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, }, + { + Name: "Contacts", Type: schema.HasMany, Schema: "User", FieldSchema: "Contact", + References: []Reference{{"ID", "User", "UserID", "Contact", "", true}}, + }, } for _, relation := range relations { diff --git a/tests/create_unique_test.go b/tests/create_unique_test.go new file mode 100644 index 00000000..8117a581 --- /dev/null +++ b/tests/create_unique_test.go @@ -0,0 +1,43 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestCreateUniqueConstraint(t *testing.T) { + user1 := GetUser("create-unique-constraint", Config{}) + if err := DB.Create(user1).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user1Contact := &Contact{UserID: &user1.ID, Email: "create-unique-constraint@email"} + if err := DB.Create(user1Contact).Error; err != nil { + t.Fatalf("errors happened when create cotract: %v", err) + } + + user2 := GetUser("create-unique-constraint2", Config{}) + if err := DB.Create(user2).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user2Contact := &Contact{UserID: &user2.ID, Email: "create-unique-constraint@email"} + err := DB.Create(user2Contact).Error + if err == nil { + t.Fatal("should return unique constraint error") + } + e, ok := err.(*gorm.ErrUniqueConstraint) + if !ok { + t.Fatalf("should return unique constraint error, got err %v", err) + } + + if len(e.ConstraintName) > 0 { + AssertEqual(t, e.ConstraintName, "idx_email") + } + + if len(e.Columns) > 0 { + AssertEqual(t, e.Columns[0], "contacts.email") + } +} diff --git a/tests/helper_test.go b/tests/helper_test.go index eee34e99..c4c4ba79 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -19,6 +19,7 @@ type Config struct { Team int Languages int Friends int + Contacts int } func GetUser(name string, config Config) *User { @@ -65,6 +66,10 @@ func GetUser(name string, config Config) *User { user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) } + for i := 0; i < config.Contacts; i++ { + user.Contacts = append(user.Contacts, &Contact{Email: name + "_" + strconv.Itoa(i+1) + "@email"}) + } + return &user } @@ -87,6 +92,19 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) { } } +func CheckContact(t *testing.T, contact Contact, expect Contact) { + if contact.ID != 0 { + var newContact Contact + if err := DB.Where("id = ?", contact.ID).First(&newContact).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newContact, contact, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Email") + } + } + + AssertObjEqual(t, contact, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Email") +} + func CheckUser(t *testing.T, user User, expect User) { if user.ID != 0 { var newUser User @@ -227,4 +245,26 @@ func CheckUser(t *testing.T, user User, expect User) { AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } }) + + t.Run("Contacts", func(t *testing.T) { + if len(user.Contacts) != len(expect.Contacts) { + t.Fatalf("contacts should equal, expect: %v, got %v", len(expect.Contacts), len(user.Contacts)) + } + + sort.Slice(user.Contacts, func(i, j int) bool { + return user.Contacts[i].ID > user.Contacts[j].ID + }) + + sort.Slice(expect.Contacts, func(i, j int) bool { + return expect.Contacts[i].ID > expect.Contacts[j].ID + }) + + for idx, contact := range user.Contacts { + if contact == nil || expect.Contacts[idx] == nil { + t.Errorf("contact#%v should equal, expect: %v, got %v", idx, expect.Contacts[idx], contact) + } else { + CheckContact(t, *contact, *expect.Contacts[idx]) + } + } + }) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 4cc8a7c3..becf8f06 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -11,7 +11,7 @@ import ( ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Contact{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/tests/tests_test.go b/tests/tests_test.go index cb73d267..efdea149 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -87,7 +87,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Contact{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0..25d2c5a6 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -26,6 +26,7 @@ type User struct { Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak;"` Friends []*User `gorm:"many2many:user_friends;"` + Contacts []*Contact Active bool } @@ -58,3 +59,9 @@ type Language struct { Code string `gorm:"primarykey"` Name string } + +type Contact struct { + gorm.Model + UserID *uint + Email string `gorm:"index:idx_email,unique"` +}