Add support for excluding a field from an index in the model struct.
This commit is contained in:
		
							parent
							
								
									adf9b80fb7
								
							
						
					
					
						commit
						8cf98c0598
					
				@ -27,7 +27,7 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
 | 
				
			|||||||
	if sqlType == "" {
 | 
						if sqlType == "" {
 | 
				
			||||||
		switch dataValue.Kind() {
 | 
							switch dataValue.Kind() {
 | 
				
			||||||
		case reflect.Bool:
 | 
							case reflect.Bool:
 | 
				
			||||||
			sqlType = "bool"
 | 
								sqlType = "numeric"
 | 
				
			||||||
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
							case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
				
			||||||
			if field.IsPrimaryKey {
 | 
								if field.IsPrimaryKey {
 | 
				
			||||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
									field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
				
			||||||
 | 
				
			|||||||
@ -5,6 +5,7 @@ import (
 | 
				
			|||||||
	"database/sql/driver"
 | 
						"database/sql/driver"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@ -436,3 +437,50 @@ func TestMultipleIndexes(t *testing.T) {
 | 
				
			|||||||
		t.Error("MultipleIndexes unique index failed")
 | 
							t.Error("MultipleIndexes unique index failed")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ExcludedColumnIndex struct {
 | 
				
			||||||
 | 
						ID      int64
 | 
				
			||||||
 | 
						Email   string `sql:"unique_index:uix_excluded_column_index_email"`
 | 
				
			||||||
 | 
						Deleted bool   `sql:"unique_index:!uix_excluded_column_index_email"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestConditionalIndexExcludesColumn(t *testing.T) {
 | 
				
			||||||
 | 
						// This behavior is only supported for DBMSes that support partial indices
 | 
				
			||||||
 | 
						// (i.e., not MySQL).
 | 
				
			||||||
 | 
						dialect := os.Getenv("GORM_DIALECT")
 | 
				
			||||||
 | 
						switch dialect {
 | 
				
			||||||
 | 
						case "", "sqlite", "postgres", "mssql":
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.DropTableIfExists(&ExcludedColumnIndex{}).Error; err != nil {
 | 
				
			||||||
 | 
							fmt.Printf("Got error when try to delete table excluded_column_index, %+v\n", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.AutoMigrate(&ExcludedColumnIndex{}).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Failed to migrate: %+v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.Save(&ExcludedColumnIndex{Email: "impl@example.com"}).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Unexpected error saving first entry: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.Save(&ExcludedColumnIndex{Email: "impl@example.com"}).Error; err == nil {
 | 
				
			||||||
 | 
							t.Error("Unique index was not created")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var u ExcludedColumnIndex
 | 
				
			||||||
 | 
						if err := DB.First(&u).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Enexpected error retrieving first entry: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						u.Deleted = true
 | 
				
			||||||
 | 
						if err := DB.Save(&u).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Unexpected error saving first entry: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.Save(&ExcludedColumnIndex{Email: "impl@example.com"}).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Conditional index failed: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										95
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										95
									
								
								scope.go
									
									
									
									
									
								
							@ -1185,40 +1185,81 @@ func (scope *Scope) autoMigrate() *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type derivedIndex struct {
 | 
				
			||||||
 | 
						columns []string
 | 
				
			||||||
 | 
						q       *DB
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) autoIndex() *Scope {
 | 
					func (scope *Scope) autoIndex() *Scope {
 | 
				
			||||||
	var indexes = map[string][]string{}
 | 
						indexes := map[string]map[string]*derivedIndex{
 | 
				
			||||||
	var uniqueIndexes = map[string][]string{}
 | 
							"INDEX":        make(map[string]*derivedIndex),
 | 
				
			||||||
 | 
							"UNIQUE_INDEX": make(map[string]*derivedIndex),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						derive := func(field *StructField, tag, prefix string) {
 | 
				
			||||||
 | 
							if name, ok := field.TagSettings[tag]; ok {
 | 
				
			||||||
 | 
								names := strings.Split(name, ",")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for _, name := range names {
 | 
				
			||||||
 | 
									exclude := false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if name == tag || name == "" {
 | 
				
			||||||
 | 
										name = fmt.Sprintf("%s_%v_%v", prefix, scope.TableName(), field.DBName)
 | 
				
			||||||
 | 
									} else if name[0] == '!' {
 | 
				
			||||||
 | 
										exclude = true
 | 
				
			||||||
 | 
										name = name[1:]
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									idx, ok := indexes[tag][name]
 | 
				
			||||||
 | 
									if !ok {
 | 
				
			||||||
 | 
										idx = &derivedIndex{q: scope.NewDB().Model(scope.Value)}
 | 
				
			||||||
 | 
										indexes[tag][name] = idx
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if exclude {
 | 
				
			||||||
 | 
										// We can't just bind this; most (all?) DBMSes don't seem
 | 
				
			||||||
 | 
										// to support parameterizing partial indices.
 | 
				
			||||||
 | 
										f, _ := scope.FieldByName(field.Name)
 | 
				
			||||||
 | 
										v, _ := driver.DefaultParameterConverter.ConvertValue(reflect.Zero(f.Field.Type()).Interface())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										// Possibilities are limited:
 | 
				
			||||||
 | 
										// https://golang.org/pkg/database/sql/driver/#Value
 | 
				
			||||||
 | 
										cond := "IS NULL"
 | 
				
			||||||
 | 
										switch value := v.(type) {
 | 
				
			||||||
 | 
										case int64, float64:
 | 
				
			||||||
 | 
											cond = "= 0"
 | 
				
			||||||
 | 
										case bool:
 | 
				
			||||||
 | 
											switch scope.Dialect().DataTypeOf(field) {
 | 
				
			||||||
 | 
											case "bool", "boolean":
 | 
				
			||||||
 | 
												cond = "= false"
 | 
				
			||||||
 | 
											default:
 | 
				
			||||||
 | 
												cond = "= 0"
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										case []byte, string:
 | 
				
			||||||
 | 
											cond = "= ''"
 | 
				
			||||||
 | 
										case time.Time:
 | 
				
			||||||
 | 
											cond = fmt.Sprintf("= '%s'", value.Format(time.RFC3339))
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										idx.q = idx.q.Where(fmt.Sprintf("%s %s", scope.Quote(field.DBName), cond))
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										idx.columns = append(idx.columns, field.DBName)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, field := range scope.GetStructFields() {
 | 
						for _, field := range scope.GetStructFields() {
 | 
				
			||||||
		if name, ok := field.TagSettings["INDEX"]; ok {
 | 
							derive(field, "INDEX", "idx")
 | 
				
			||||||
			names := strings.Split(name, ",")
 | 
							derive(field, "UNIQUE_INDEX", "uix")
 | 
				
			||||||
 | 
					 | 
				
			||||||
			for _, name := range names {
 | 
					 | 
				
			||||||
				if name == "INDEX" || name == "" {
 | 
					 | 
				
			||||||
					name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				indexes[name] = append(indexes[name], field.DBName)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
 | 
					 | 
				
			||||||
			names := strings.Split(name, ",")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for _, name := range names {
 | 
					 | 
				
			||||||
				if name == "UNIQUE_INDEX" || name == "" {
 | 
					 | 
				
			||||||
					name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for name, columns := range indexes {
 | 
						for name, idx := range indexes["INDEX"] {
 | 
				
			||||||
		scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
 | 
							idx.q.AddIndex(name, idx.columns...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for name, columns := range uniqueIndexes {
 | 
						for name, idx := range indexes["UNIQUE_INDEX"] {
 | 
				
			||||||
		scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
 | 
							idx.q.AddUniqueIndex(name, idx.columns...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return scope
 | 
						return scope
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user