Add create value converter
This commit is contained in:
		
							parent
							
								
									cbbf8f3d49
								
							
						
					
					
						commit
						15ce5b3cdd
					
				| @ -2,6 +2,7 @@ package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| @ -19,11 +20,15 @@ func SaveBeforeAssociations(db *gorm.DB) { | ||||
| 
 | ||||
| func Create(db *gorm.DB) { | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Insert{ | ||||
| 		Table: clause.Table{Table: db.Statement.Table}, | ||||
| 		Table: clause.Table{Name: db.Statement.Table}, | ||||
| 	}) | ||||
| 	values, _ := ConvertToCreateValues(db.Statement) | ||||
| 	db.Statement.AddClause(values) | ||||
| 
 | ||||
| 	db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 	fmt.Printf("%+v\n", values) | ||||
| 	fmt.Println(err) | ||||
| 	fmt.Println(result) | ||||
| 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | ||||
| @ -36,3 +41,83 @@ func AfterCreate(db *gorm.DB) { | ||||
| 	// after save
 | ||||
| 	// after create
 | ||||
| } | ||||
| 
 | ||||
| // ConvertToCreateValues convert to create values
 | ||||
| func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) { | ||||
| 	switch value := stmt.Dest.(type) { | ||||
| 	case map[string]interface{}: | ||||
| 		return ConvertMapToValues(stmt, value), nil | ||||
| 	case []map[string]interface{}: | ||||
| 		return ConvertSliceOfMapToValues(stmt, value), nil | ||||
| 	default: | ||||
| 		var ( | ||||
| 			values                    = clause.Values{} | ||||
| 			selectColumns, restricted = SelectAndOmitColumns(stmt) | ||||
| 			curTime                   = stmt.DB.NowFunc() | ||||
| 			isZero                    = false | ||||
| 			returnningValues          []map[string]interface{} | ||||
| 		) | ||||
| 
 | ||||
| 		for _, db := range stmt.Schema.DBNames { | ||||
| 			if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { | ||||
| 				values.Columns = append(values.Columns, clause.Column{Name: db}) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest)) | ||||
| 		switch reflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			values.Values = make([][]interface{}, reflectValue.Len()) | ||||
| 			for i := 0; i < reflectValue.Len(); i++ { | ||||
| 				rv := reflect.Indirect(reflectValue.Index(i)) | ||||
| 				values.Values[i] = make([]interface{}, len(values.Columns)) | ||||
| 				for idx, column := range values.Columns { | ||||
| 					field := stmt.Schema.FieldsByDBName[column.Name] | ||||
| 					if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { | ||||
| 						if field.DefaultValueInterface != nil { | ||||
| 							values.Values[i][idx] = field.DefaultValueInterface | ||||
| 							field.Set(rv, field.DefaultValueInterface) | ||||
| 						} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { | ||||
| 							field.Set(rv, curTime) | ||||
| 							values.Values[i][idx], _ = field.ValueOf(rv) | ||||
| 						} else if field.HasDefaultValue { | ||||
| 							if len(returnningValues) == 0 { | ||||
| 								returnningValues = make([]map[string]interface{}, reflectValue.Len()) | ||||
| 							} | ||||
| 
 | ||||
| 							if returnningValues[i] == nil { | ||||
| 								returnningValues[i] = map[string]interface{}{} | ||||
| 							} | ||||
| 
 | ||||
| 							// FIXME
 | ||||
| 							returnningValues[i][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} | ||||
| 			for idx, column := range values.Columns { | ||||
| 				field := stmt.Schema.FieldsByDBName[column.Name] | ||||
| 				if values.Values[0][idx], _ = field.ValueOf(reflectValue); isZero { | ||||
| 					if field.DefaultValueInterface != nil { | ||||
| 						values.Values[0][idx] = field.DefaultValueInterface | ||||
| 						field.Set(reflectValue, field.DefaultValueInterface) | ||||
| 					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { | ||||
| 						field.Set(reflectValue, curTime) | ||||
| 						values.Values[0][idx], _ = field.ValueOf(reflectValue) | ||||
| 					} else if field.HasDefaultValue { | ||||
| 						if len(returnningValues) == 0 { | ||||
| 							returnningValues = make([]map[string]interface{}, 1) | ||||
| 						} | ||||
| 
 | ||||
| 						values.Values[0][idx] = clause.Expr{SQL: "DEFAULT"} | ||||
| 						returnningValues[0][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() | ||||
| 					} else if field.PrimaryKey { | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return values, returnningValues | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										97
									
								
								callbacks/helper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								callbacks/helper.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,97 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"sort" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| ) | ||||
| 
 | ||||
| // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 | ||||
| func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { | ||||
| 	results := map[string]bool{} | ||||
| 
 | ||||
| 	// select columns
 | ||||
| 	for _, column := range stmt.Selects { | ||||
| 		if field := stmt.Schema.LookUpField(column); field != nil { | ||||
| 			results[field.DBName] = true | ||||
| 		} else { | ||||
| 			results[column] = true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// omit columns
 | ||||
| 	for _, omit := range stmt.Omits { | ||||
| 		if field := stmt.Schema.LookUpField(omit); field != nil { | ||||
| 			results[field.DBName] = false | ||||
| 		} else { | ||||
| 			results[omit] = false | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return results, len(stmt.Selects) > 0 | ||||
| } | ||||
| 
 | ||||
| // ConvertMapToValues convert map to values
 | ||||
| func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { | ||||
| 	columns := make([]string, 0, len(mapValue)) | ||||
| 	selectColumns, restricted := SelectAndOmitColumns(stmt) | ||||
| 
 | ||||
| 	var keys []string | ||||
| 	for k, _ := range mapValue { | ||||
| 		keys = append(keys, k) | ||||
| 	} | ||||
| 	sort.Strings(keys) | ||||
| 
 | ||||
| 	for _, k := range keys { | ||||
| 		if field := stmt.Schema.LookUpField(k); field != nil { | ||||
| 			k = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||
| 			columns = append(columns, k) | ||||
| 			values.Values[0] = append(values.Values[0], mapValue[k]) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // ConvertSliceOfMapToValues convert slice of map to values
 | ||||
| func ConvertSliceOfMapToValues(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { | ||||
| 	var ( | ||||
| 		columns                   = []string{} | ||||
| 		result                    = map[string][]interface{}{} | ||||
| 		selectColumns, restricted = SelectAndOmitColumns(stmt) | ||||
| 	) | ||||
| 
 | ||||
| 	for idx, mapValue := range mapValues { | ||||
| 		for k, v := range mapValue { | ||||
| 			if field := stmt.Schema.LookUpField(k); field != nil { | ||||
| 				k = field.DBName | ||||
| 			} | ||||
| 
 | ||||
| 			if _, ok := result[k]; !ok { | ||||
| 				if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||
| 					result[k] = make([]interface{}, len(mapValues)) | ||||
| 					columns = append(columns, k) | ||||
| 				} else { | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			result[k][idx] = v | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Strings(columns) | ||||
| 	values.Values = make([][]interface{}, len(mapValues)) | ||||
| 	for idx, column := range columns { | ||||
| 		for i, v := range result[column] { | ||||
| 			if i == 0 { | ||||
| 				values.Values[i] = make([]interface{}, len(columns)) | ||||
| 			} | ||||
| 			values.Values[i][idx] = v | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| @ -99,7 +99,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { | ||||
| func (db *DB) Omit(columns ...string) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 
 | ||||
| 	if len(columns) == 1 && strings.Contains(columns[0], ",") { | ||||
| 	if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { | ||||
| 		tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) | ||||
| 	} else { | ||||
| 		tx.Statement.Omits = columns | ||||
|  | ||||
| @ -7,7 +7,7 @@ type Values struct { | ||||
| 
 | ||||
| // Name from clause name
 | ||||
| func (Values) Name() string { | ||||
| 	return "" | ||||
| 	return "VALUES" | ||||
| } | ||||
| 
 | ||||
| // Build build from clause
 | ||||
| @ -40,6 +40,7 @@ func (values Values) Build(builder Builder) { | ||||
| 
 | ||||
| // MergeClause merge values clauses
 | ||||
| func (values Values) MergeClause(clause *Clause) { | ||||
| 	clause.Name = "" | ||||
| 	if v, ok := clause.Expression.(Values); ok { | ||||
| 		values.Values = append(v.Values, values.Values...) | ||||
| 	} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu