Implement preload support
This commit is contained in:
		
							parent
							
								
									41697d58d3
								
							
						
					
					
						commit
						b549f9bb9a
					
				| @ -1,9 +1,196 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/clause" | ||||||
| 	"github.com/jinzhu/gorm/schema" | 	"github.com/jinzhu/gorm/schema" | ||||||
|  | 	"github.com/jinzhu/gorm/utils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func preload(db *gorm.DB, preloadFields []string, rel *schema.Relationship) { | // getRelationsValue get relations's value from a reflect value
 | ||||||
|  | func getRelationsValue(reflectValue reflect.Value, rels []*schema.Relationship) (reflectResults reflect.Value) { | ||||||
|  | 	for _, rel := range rels { | ||||||
|  | 		reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) | ||||||
|  | 
 | ||||||
|  | 		appendToResults := func(value reflect.Value) { | ||||||
|  | 			if _, isZero := rel.Field.ValueOf(value); !isZero { | ||||||
|  | 				result := reflect.Indirect(rel.Field.ReflectValueOf(value)) | ||||||
|  | 				switch result.Kind() { | ||||||
|  | 				case reflect.Struct: | ||||||
|  | 					reflectResults = reflect.Append(reflectResults, result) | ||||||
|  | 				case reflect.Slice, reflect.Array: | ||||||
|  | 					for i := 0; i < value.Len(); i++ { | ||||||
|  | 						reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		switch reflectValue.Kind() { | ||||||
|  | 		case reflect.Struct: | ||||||
|  | 			appendToResults(reflectValue) | ||||||
|  | 		case reflect.Slice: | ||||||
|  | 			for i := 0; i < reflectValue.Len(); i++ { | ||||||
|  | 				appendToResults(reflectValue.Index(i)) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		reflectValue = reflectResults | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Field) (map[string][]reflect.Value, [][]interface{}) { | ||||||
|  | 	var ( | ||||||
|  | 		fieldValues = make([]reflect.Value, len(fields)) | ||||||
|  | 		results     = [][]interface{}{} | ||||||
|  | 		dataResults = map[string][]reflect.Value{} | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	switch reflectValue.Kind() { | ||||||
|  | 	case reflect.Struct: | ||||||
|  | 		results = [][]interface{}{make([]interface{}, len(fields))} | ||||||
|  | 
 | ||||||
|  | 		for idx, field := range fields { | ||||||
|  | 			fieldValues[idx] = field.ReflectValueOf(reflectValue) | ||||||
|  | 			results[0][idx] = fieldValues[idx].Interface() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} | ||||||
|  | 	case reflect.Slice, reflect.Array: | ||||||
|  | 		for i := 0; i < reflectValue.Len(); i++ { | ||||||
|  | 			for idx, field := range fields { | ||||||
|  | 				fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			dataKey := utils.ToStringKey(fieldValues...) | ||||||
|  | 			if _, ok := dataResults[dataKey]; !ok { | ||||||
|  | 				result := make([]interface{}, len(fieldValues)) | ||||||
|  | 				for idx, fieldValue := range fieldValues { | ||||||
|  | 					result[idx] = fieldValue.Interface() | ||||||
|  | 				} | ||||||
|  | 				results = append(results, result) | ||||||
|  | 
 | ||||||
|  | 				dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} | ||||||
|  | 			} else { | ||||||
|  | 				dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return dataResults, results | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}) reflect.Value { | ||||||
|  | 	results := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) | ||||||
|  | 	queryValues := make([]interface{}, len(foreignValues)) | ||||||
|  | 	if len(foreignKeys) == 1 { | ||||||
|  | 		for idx, r := range foreignValues { | ||||||
|  | 			queryValues[idx] = r[0] | ||||||
|  | 		} | ||||||
|  | 		tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Addr().Interface()) | ||||||
|  | 	} else { | ||||||
|  | 		for idx, r := range foreignValues { | ||||||
|  | 			queryValues[idx] = r | ||||||
|  | 		} | ||||||
|  | 		tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Addr().Interface()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return results | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func preload(tx *gorm.DB, rels []*schema.Relationship, conds []interface{}) { | ||||||
|  | 	var ( | ||||||
|  | 		reflectValue     = tx.Statement.ReflectValue | ||||||
|  | 		rel              = rels[len(rels)-1] | ||||||
|  | 		relForeignKeys   []string | ||||||
|  | 		relForeignFields []*schema.Field | ||||||
|  | 		foreignFields    []*schema.Field | ||||||
|  | 		foreignValues    [][]interface{} | ||||||
|  | 		identityMap      = map[string][]reflect.Value{} | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if len(rels) > 1 { | ||||||
|  | 		reflectValue = getRelationsValue(reflectValue, rels[:len(rels)]) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if rel.JoinTable != nil { | ||||||
|  | 		var joinForeignFields, joinRelForeignFields []*schema.Field | ||||||
|  | 		var joinForeignKeys []string | ||||||
|  | 		for _, ref := range rel.References { | ||||||
|  | 			if ref.OwnPrimaryKey { | ||||||
|  | 				joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) | ||||||
|  | 				joinForeignFields = append(joinForeignFields, ref.ForeignKey) | ||||||
|  | 				foreignFields = append(foreignFields, ref.PrimaryKey) | ||||||
|  | 			} else if ref.PrimaryValue != "" { | ||||||
|  | 				tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) | ||||||
|  | 			} else { | ||||||
|  | 				joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) | ||||||
|  | 				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) | ||||||
|  | 				relForeignFields = append(relForeignFields, ref.PrimaryKey) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) | ||||||
|  | 		joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues) | ||||||
|  | 
 | ||||||
|  | 		// convert join identity map to relation identity map
 | ||||||
|  | 		fieldValues := make([]reflect.Value, len(foreignFields)) | ||||||
|  | 		joinFieldValues := make([]reflect.Value, len(joinForeignFields)) | ||||||
|  | 		for i := 0; i < joinResults.Len(); i++ { | ||||||
|  | 			for idx, field := range foreignFields { | ||||||
|  | 				fieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for idx, field := range joinForeignFields { | ||||||
|  | 				joinFieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { | ||||||
|  | 				identityMap[utils.ToStringKey(joinFieldValues...)] = results | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		_, foreignValues = getIdentityFieldValuesMap(joinResults, joinRelForeignFields) | ||||||
|  | 	} else { | ||||||
|  | 		for _, ref := range rel.References { | ||||||
|  | 			if ref.OwnPrimaryKey { | ||||||
|  | 				relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) | ||||||
|  | 				relForeignFields = append(relForeignFields, ref.ForeignKey) | ||||||
|  | 				foreignFields = append(foreignFields, ref.PrimaryKey) | ||||||
|  | 			} else if ref.PrimaryValue != "" { | ||||||
|  | 				tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) | ||||||
|  | 			} else { | ||||||
|  | 				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) | ||||||
|  | 				relForeignFields = append(relForeignFields, ref.PrimaryKey) | ||||||
|  | 				foreignFields = append(foreignFields, ref.ForeignKey) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues) | ||||||
|  | 
 | ||||||
|  | 	fieldValues := make([]reflect.Value, len(foreignFields)) | ||||||
|  | 	for i := 0; i < reflectResults.Len(); i++ { | ||||||
|  | 		for idx, field := range foreignFields { | ||||||
|  | 			fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { | ||||||
|  | 			reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) | ||||||
|  | 			switch reflectFieldValue.Kind() { | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) | ||||||
|  | 				rel.Field.Set(data, elem.Interface()) | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) | ||||||
|  | 				rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -25,6 +25,7 @@ func Query(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		// inline joins
 | ||||||
| 		if len(db.Statement.Joins) != 0 { | 		if len(db.Statement.Joins) != 0 { | ||||||
| 			joins := []clause.Join{} | 			joins := []clause.Join{} | ||||||
| 
 | 
 | ||||||
| @ -101,7 +102,6 @@ func Query(db *gorm.DB) { | |||||||
| func Preload(db *gorm.DB) { | func Preload(db *gorm.DB) { | ||||||
| 	if len(db.Statement.Preloads) > 0 { | 	if len(db.Statement.Preloads) > 0 { | ||||||
| 		preloadMap := map[string][]string{} | 		preloadMap := map[string][]string{} | ||||||
| 
 |  | ||||||
| 		for name := range db.Statement.Preloads { | 		for name := range db.Statement.Preloads { | ||||||
| 			preloadFields := strings.Split(name, ".") | 			preloadFields := strings.Split(name, ".") | ||||||
| 			for idx := range preloadFields { | 			for idx := range preloadFields { | ||||||
| @ -118,27 +118,22 @@ func Preload(db *gorm.DB) { | |||||||
| 		sort.Strings(preloadNames) | 		sort.Strings(preloadNames) | ||||||
| 
 | 
 | ||||||
| 		for _, name := range preloadNames { | 		for _, name := range preloadNames { | ||||||
| 			curSchema := db.Statement.Schema | 			var ( | ||||||
| 			preloadFields := preloadMap[name] | 				curSchema     = db.Statement.Schema | ||||||
|  | 				preloadFields = preloadMap[name] | ||||||
|  | 				rels          = make([]*schema.Relationship, len(preloadFields)) | ||||||
|  | 			) | ||||||
| 
 | 
 | ||||||
| 			for idx, preloadField := range preloadFields { | 			for idx, preloadField := range preloadFields { | ||||||
| 				if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { | 				if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { | ||||||
| 					if idx == len(preloadFields)-1 { | 					rels[idx] = rel | ||||||
| 						conds := db.Statement.Preloads[strings.Join(preloadFields[:idx+1], ".")] | 					curSchema = rel.FieldSchema | ||||||
| 
 |  | ||||||
| 						switch rel.Type { |  | ||||||
| 						case schema.HasOne: |  | ||||||
| 						case schema.HasMany: |  | ||||||
| 						case schema.BelongsTo: |  | ||||||
| 						case schema.Many2Many: |  | ||||||
| 						} |  | ||||||
| 					} else { |  | ||||||
| 						curSchema = rel.FieldSchema |  | ||||||
| 					} |  | ||||||
| 				} else { | 				} else { | ||||||
| 					db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) | 					db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | 
 | ||||||
|  | 			preload(db.Session(&gorm.Session{}), rels, db.Statement.Preloads[name]) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -95,6 +95,15 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { | |||||||
| 		} | 		} | ||||||
| 	case string: | 	case string: | ||||||
| 		stmt.DB.Dialector.QuoteTo(writer, v) | 		stmt.DB.Dialector.QuoteTo(writer, v) | ||||||
|  | 	case []string: | ||||||
|  | 		writer.WriteByte('(') | ||||||
|  | 		for idx, d := range v { | ||||||
|  | 			if idx != 0 { | ||||||
|  | 				writer.WriteString(",") | ||||||
|  | 			} | ||||||
|  | 			stmt.DB.Dialector.QuoteTo(writer, d) | ||||||
|  | 		} | ||||||
|  | 		writer.WriteByte(')') | ||||||
| 	default: | 	default: | ||||||
| 		stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) | 		stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"runtime" | 	"runtime" | ||||||
|  | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"unicode" | 	"unicode" | ||||||
| ) | ) | ||||||
| @ -38,3 +39,24 @@ func CheckTruth(val interface{}) bool { | |||||||
| 
 | 
 | ||||||
| 	return !reflect.ValueOf(val).IsZero() | 	return !reflect.ValueOf(val).IsZero() | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func ToStringKey(values ...reflect.Value) string { | ||||||
|  | 	results := make([]string, len(values)) | ||||||
|  | 
 | ||||||
|  | 	for idx, value := range values { | ||||||
|  | 		rv := reflect.Indirect(value).Interface() | ||||||
|  | 
 | ||||||
|  | 		switch v := rv.(type) { | ||||||
|  | 		case string: | ||||||
|  | 			results[idx] = v | ||||||
|  | 		case []byte: | ||||||
|  | 			results[idx] = string(v) | ||||||
|  | 		case uint: | ||||||
|  | 			results[idx] = strconv.FormatUint(uint64(v), 10) | ||||||
|  | 		default: | ||||||
|  | 			results[idx] = fmt.Sprint(v) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return strings.Join(results, "_") | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu