diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 33f4aa50..727609ab 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -27,7 +27,7 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { if sqlType == "" { switch dataValue.Kind() { case reflect.Bool: - sqlType = "bool" + sqlType = "numeric" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if field.IsPrimaryKey { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" diff --git a/migration_test.go b/migration_test.go index 95c2c571..94b3e30a 100644 --- a/migration_test.go +++ b/migration_test.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "errors" "fmt" + "os" "reflect" "testing" "time" @@ -436,3 +437,50 @@ func TestMultipleIndexes(t *testing.T) { t.Error("MultipleIndexes unique index failed") } } + +type ExcludedColumnIndex struct { + ID int64 + Email string `sql:"unique_index:uix_excluded_column_index_email"` + Deleted bool `sql:"unique_index:!uix_excluded_column_index_email"` +} + +func TestConditionalIndexExcludesColumn(t *testing.T) { + // This behavior is only supported for DBMSes that support partial indices + // (i.e., not MySQL). + dialect := os.Getenv("GORM_DIALECT") + switch dialect { + case "", "sqlite", "postgres", "mssql": + default: + return + } + + if err := DB.DropTableIfExists(&ExcludedColumnIndex{}).Error; err != nil { + fmt.Printf("Got error when try to delete table excluded_column_index, %+v\n", err) + } + + if err := DB.AutoMigrate(&ExcludedColumnIndex{}).Error; err != nil { + t.Errorf("Failed to migrate: %+v", err) + } + + if err := DB.Save(&ExcludedColumnIndex{Email: "impl@example.com"}).Error; err != nil { + t.Errorf("Unexpected error saving first entry: %v", err) + } + + if err := DB.Save(&ExcludedColumnIndex{Email: "impl@example.com"}).Error; err == nil { + t.Error("Unique index was not created") + } + + var u ExcludedColumnIndex + if err := DB.First(&u).Error; err != nil { + t.Errorf("Enexpected error retrieving first entry: %v", err) + } + + u.Deleted = true + if err := DB.Save(&u).Error; err != nil { + t.Errorf("Unexpected error saving first entry: %v", err) + } + + if err := DB.Save(&ExcludedColumnIndex{Email: "impl@example.com"}).Error; err != nil { + t.Errorf("Conditional index failed: %v", err) + } +} diff --git a/scope.go b/scope.go index 45f7185f..e31a26b1 100644 --- a/scope.go +++ b/scope.go @@ -1185,40 +1185,81 @@ func (scope *Scope) autoMigrate() *Scope { return scope } +type derivedIndex struct { + columns []string + q *DB +} + func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} + indexes := map[string]map[string]*derivedIndex{ + "INDEX": make(map[string]*derivedIndex), + "UNIQUE_INDEX": make(map[string]*derivedIndex), + } + + derive := func(field *StructField, tag, prefix string) { + if name, ok := field.TagSettings[tag]; ok { + names := strings.Split(name, ",") + + for _, name := range names { + exclude := false + + if name == tag || name == "" { + name = fmt.Sprintf("%s_%v_%v", prefix, scope.TableName(), field.DBName) + } else if name[0] == '!' { + exclude = true + name = name[1:] + } + + idx, ok := indexes[tag][name] + if !ok { + idx = &derivedIndex{q: scope.NewDB().Model(scope.Value)} + indexes[tag][name] = idx + } + + if exclude { + // We can't just bind this; most (all?) DBMSes don't seem + // to support parameterizing partial indices. + f, _ := scope.FieldByName(field.Name) + v, _ := driver.DefaultParameterConverter.ConvertValue(reflect.Zero(f.Field.Type()).Interface()) + + // Possibilities are limited: + // https://golang.org/pkg/database/sql/driver/#Value + cond := "IS NULL" + switch value := v.(type) { + case int64, float64: + cond = "= 0" + case bool: + switch scope.Dialect().DataTypeOf(field) { + case "bool", "boolean": + cond = "= false" + default: + cond = "= 0" + } + case []byte, string: + cond = "= ''" + case time.Time: + cond = fmt.Sprintf("= '%s'", value.Format(time.RFC3339)) + } + + idx.q = idx.q.Where(fmt.Sprintf("%s %s", scope.Quote(field.DBName), cond)) + } else { + idx.columns = append(idx.columns, field.DBName) + } + } + } + } for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) - } - indexes[name] = append(indexes[name], field.DBName) - } - } - - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) - } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) - } - } + derive(field, "INDEX", "idx") + derive(field, "UNIQUE_INDEX", "uix") } - for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + for name, idx := range indexes["INDEX"] { + idx.q.AddIndex(name, idx.columns...) } - for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + for name, idx := range indexes["UNIQUE_INDEX"] { + idx.q.AddUniqueIndex(name, idx.columns...) } return scope