diff --git a/clause/where.go b/clause/where.go index f7cd3318..a0f4598d 100644 --- a/clause/where.go +++ b/clause/where.go @@ -128,7 +128,7 @@ func (not NotConditions) Build(builder Builder) { if negationBuilder, ok := c.(NegationExpressionBuilder); ok { negationBuilder.NegationBuild(builder) } else { - builder.WriteString(" NOT ") + builder.WriteString("NOT ") c.Build(builder) } } diff --git a/statement.go b/statement.go index e65a064f..c03f6f88 100644 --- a/statement.go +++ b/statement.go @@ -265,7 +265,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } case map[string]interface{}: for i, j := range v { - conds = append(conds, clause.Eq{Column: i, Value: j}) + reflectValue := reflect.Indirect(reflect.ValueOf(j)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + conds = append(conds, clause.IN{Column: i, Values: values}) + default: + conds = append(conds, clause.Eq{Column: i, Value: j}) + } } default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) @@ -299,6 +310,21 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } } } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return + } + } + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } } diff --git a/tests/create_test.go b/tests/create_test.go index 75059f18..46cc06c6 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -307,7 +307,7 @@ func TestCreateWithNoGORMPrimaryKey(t *testing.T) { func TestSelectWithCreate(t *testing.T) { user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) - DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "UpdatedAt", "Age", "Active").Create(&user) + DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user) var user2 User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) diff --git a/tests/query_test.go b/tests/query_test.go index 594fc268..c9eb5903 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -179,6 +179,45 @@ func TestFillSmallerStruct(t *testing.T) { } } +func TestNot(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Not(map[string]interface{}{"name": "jinzhu"}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu1").Not("name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ AND NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .users.\\..deleted_at. IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}),