Make inesrt into db works
This commit is contained in:
		
							parent
							
								
									868ae052a1
								
							
						
					
					
						commit
						fa22807e12
					
				| @ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) { | ||||
| 
 | ||||
| 	if stmt := db.Statement; stmt != nil { | ||||
| 		db.Logger.Trace(curTime, func() (string, int64) { | ||||
| 			return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected | ||||
| 			return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected | ||||
| 		}, db.Error) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| @ -11,8 +10,6 @@ import ( | ||||
| func BeforeCreate(db *gorm.DB) { | ||||
| 	// before save
 | ||||
| 	// before create
 | ||||
| 
 | ||||
| 	// assign timestamp
 | ||||
| } | ||||
| 
 | ||||
| func SaveBeforeAssociations(db *gorm.DB) { | ||||
| @ -22,16 +19,29 @@ func Create(db *gorm.DB) { | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Insert{ | ||||
| 		Table: clause.Table{Name: db.Statement.Table}, | ||||
| 	}) | ||||
| 	values, _ := ConvertToCreateValues(db.Statement) | ||||
| 	db.Statement.AddClause(values) | ||||
| 	db.Statement.AddClause(ConvertToCreateValues(db.Statement)) | ||||
| 
 | ||||
| 	db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 	fmt.Printf("%+v\n", values) | ||||
| 	fmt.Println(err) | ||||
| 	fmt.Println(result) | ||||
| 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | ||||
| 	if err == nil { | ||||
| 		if db.Statement.Schema != nil { | ||||
| 			if insertID, err := result.LastInsertId(); err == nil { | ||||
| 				switch db.Statement.ReflectValue.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||
| 						db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 						insertID-- | ||||
| 					} | ||||
| 				case reflect.Struct: | ||||
| 					db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 	} else { | ||||
| 		db.AddError(err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func SaveAfterAssociations(db *gorm.DB) { | ||||
| @ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| // ConvertToCreateValues convert to create values
 | ||||
| func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) { | ||||
| func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { | ||||
| 	switch value := stmt.Dest.(type) { | ||||
| 	case map[string]interface{}: | ||||
| 		return ConvertMapToValues(stmt, value), nil | ||||
| 		return ConvertMapToValues(stmt, value) | ||||
| 	case []map[string]interface{}: | ||||
| 		return ConvertSliceOfMapToValues(stmt, value), nil | ||||
| 		return ConvertSliceOfMapToValues(stmt, value) | ||||
| 	default: | ||||
| 		var ( | ||||
| 			values                    = clause.Values{} | ||||
| 			selectColumns, restricted = SelectAndOmitColumns(stmt) | ||||
| 			curTime                   = stmt.DB.NowFunc() | ||||
| 			isZero                    = false | ||||
| 			returnningValues          []map[string]interface{} | ||||
| 		) | ||||
| 
 | ||||
| 		for _, db := range stmt.Schema.DBNames { | ||||
| @ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest)) | ||||
| 		switch reflectValue.Kind() { | ||||
| 		switch stmt.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			values.Values = make([][]interface{}, reflectValue.Len()) | ||||
| 			values.Values = make([][]interface{}, stmt.ReflectValue.Len()) | ||||
| 			defaultValueFieldsHavingValue := map[string][]interface{}{} | ||||
| 			for i := 0; i < reflectValue.Len(); i++ { | ||||
| 				rv := reflect.Indirect(reflectValue.Index(i)) | ||||
| 			for i := 0; i < stmt.ReflectValue.Len(); i++ { | ||||
| 				rv := reflect.Indirect(stmt.ReflectValue.Index(i)) | ||||
| 				values.Values[i] = make([]interface{}, len(values.Columns)) | ||||
| 				for idx, column := range values.Columns { | ||||
| 					field := stmt.Schema.FieldsByDBName[column.Name] | ||||
| @ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in | ||||
| 					if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { | ||||
| 						if v, isZero := field.ValueOf(rv); !isZero { | ||||
| 							if len(defaultValueFieldsHavingValue[db]) == 0 { | ||||
| 								defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len()) | ||||
| 								defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len()) | ||||
| 							} | ||||
| 							defaultValueFieldsHavingValue[db][i] = v | ||||
| 						} | ||||
| @ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in | ||||
| 			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} | ||||
| 			for idx, column := range values.Columns { | ||||
| 				field := stmt.Schema.FieldsByDBName[column.Name] | ||||
| 				if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero { | ||||
| 				if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { | ||||
| 					if field.DefaultValueInterface != nil { | ||||
| 						values.Values[0][idx] = field.DefaultValueInterface | ||||
| 						field.Set(reflectValue, field.DefaultValueInterface) | ||||
| 						field.Set(stmt.ReflectValue, field.DefaultValueInterface) | ||||
| 					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { | ||||
| 						field.Set(reflectValue, curTime) | ||||
| 						values.Values[0][idx], _ = field.ValueOf(reflectValue) | ||||
| 						field.Set(stmt.ReflectValue, curTime) | ||||
| 						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for db, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||
| 				if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { | ||||
| 					if v, isZero := field.ValueOf(reflectValue); !isZero { | ||||
| 					if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | ||||
| 						values.Columns = append(values.Columns, clause.Column{Name: db}) | ||||
| 						values.Values[0] = append(values.Values[0], v) | ||||
| 					} | ||||
| @ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		return values, returnningValues | ||||
| 		return values | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -1,8 +1,6 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| ) | ||||
| @ -15,10 +13,8 @@ func Query(db *gorm.DB) { | ||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 	} | ||||
| 
 | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	fmt.Println(err) | ||||
| 	fmt.Println(result) | ||||
| 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | ||||
| 	rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	db.AddError(err) | ||||
| } | ||||
| 
 | ||||
| func Preload(db *gorm.DB) { | ||||
|  | ||||
| @ -93,29 +93,28 @@ type logger struct { | ||||
| 
 | ||||
| // LogMode log mode
 | ||||
| func (l logger) LogMode(level LogLevel) Interface { | ||||
| 	config := l.Config | ||||
| 	config.LogLevel = level | ||||
| 	return logger{Writer: l.Writer, Config: config} | ||||
| 	l.LogLevel = level | ||||
| 	return l | ||||
| } | ||||
| 
 | ||||
| // Info print info
 | ||||
| func (l logger) Info(msg string, data ...interface{}) { | ||||
| 	if l.LogLevel >= Info { | ||||
| 		l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) | ||||
| 		l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Warn print warn messages
 | ||||
| func (l logger) Warn(msg string, data ...interface{}) { | ||||
| 	if l.LogLevel >= Warn { | ||||
| 		l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) | ||||
| 		l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Error print error messages
 | ||||
| func (l logger) Error(msg string, data ...interface{}) { | ||||
| 	if l.LogLevel >= Error { | ||||
| 		l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) | ||||
| 		l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -123,7 +122,11 @@ func (l logger) Error(msg string, data ...interface{}) { | ||||
| func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { | ||||
| 	if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold { | ||||
| 		sql, rows := fc() | ||||
| 		l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) | ||||
| 		fileline := utils.FileWithLineNum() | ||||
| 		if err != nil { | ||||
| 			fileline += " " + err.Error() | ||||
| 		} | ||||
| 		l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql) | ||||
| 	} else if l.LogLevel >= Info { | ||||
| 		sql, rows := fc() | ||||
| 		l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) | ||||
|  | ||||
| @ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v | ||||
| 		case bool: | ||||
| 			vars[idx] = fmt.Sprint(v) | ||||
| 		case time.Time: | ||||
| 			if v.IsZero() { | ||||
| 				vars[idx] = escaper + "0000-00-00 00:00:00" + escaper | ||||
| 			} else { | ||||
| 				vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper | ||||
| 			} | ||||
| 		case []byte: | ||||
| 			if isPrintable(v) { | ||||
| 				vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper | ||||
| @ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v | ||||
| 				vars[idx] = "NULL" | ||||
| 			} else { | ||||
| 				rv := reflect.Indirect(reflect.ValueOf(v)) | ||||
| 				if !rv.IsValid() { | ||||
| 					vars[idx] = "NULL" | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				for _, t := range convertableTypes { | ||||
| 					if rv.Type().ConvertibleTo(t) { | ||||
| 						convertParams(rv.Convert(t).Interface(), idx) | ||||
|  | ||||
| @ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||
| 		var err error | ||||
| 		field.Creatable = false | ||||
| 		field.Updatable = false | ||||
| 		if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { | ||||
| 		if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { | ||||
| 			schema.err = err | ||||
| 		} | ||||
| 		for _, ef := range field.EmbeddedSchema.Fields { | ||||
|  | ||||
| @ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) { | ||||
| 		} | ||||
| 	) | ||||
| 
 | ||||
| 	if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { | ||||
| 	if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { | ||||
| 		schema.err = err | ||||
| 		return | ||||
| 	} | ||||
| @ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { | ||||
| 	if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { | ||||
| 		schema.err = err | ||||
| 	} | ||||
| 	relation.JoinTable.Name = many2many | ||||
|  | ||||
| @ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field { | ||||
| } | ||||
| 
 | ||||
| // get data type from dialector
 | ||||
| func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { | ||||
| 	modelType := reflect.ValueOf(dest).Type() | ||||
| func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) { | ||||
| 	reflectValue := reflect.ValueOf(dest) | ||||
| 	modelType := reflectValue.Type() | ||||
| 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { | ||||
| 		modelType = modelType.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	if modelType.Kind() != reflect.Struct { | ||||
| 		if modelType.PkgPath() == "" { | ||||
| 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||
| 			return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||
| 		} | ||||
| 		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | ||||
| 		return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | ||||
| 	} | ||||
| 
 | ||||
| 	if v, ok := cacheStore.Load(modelType); ok { | ||||
| 		return v.(*Schema), nil | ||||
| 		return v.(*Schema), reflectValue, nil | ||||
| 	} | ||||
| 
 | ||||
| 	schema := &Schema{ | ||||
| @ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | ||||
| 	for _, field := range schema.Fields { | ||||
| 		if field.DataType == "" && field.Creatable { | ||||
| 			if schema.parseRelation(field); schema.err != nil { | ||||
| 				return schema, schema.err | ||||
| 				return schema, reflectValue, schema.err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return schema, schema.err | ||||
| 	return schema, reflectValue, schema.err | ||||
| } | ||||
|  | ||||
							
								
								
									
										12
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								statement.go
									
									
									
									
									
								
							| @ -5,6 +5,7 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| @ -32,7 +33,7 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { | ||||
| func (inst *Instance) AddError(err error) { | ||||
| 	if inst.Error == nil { | ||||
| 		inst.Error = err | ||||
| 	} else { | ||||
| 	} else if err != nil { | ||||
| 		inst.Error = fmt.Errorf("%v; %w", inst.Error, err) | ||||
| 	} | ||||
| } | ||||
| @ -42,6 +43,7 @@ type Statement struct { | ||||
| 	Table        string | ||||
| 	Model        interface{} | ||||
| 	Dest         interface{} | ||||
| 	ReflectValue reflect.Value | ||||
| 	Clauses      map[string]clause.Clause | ||||
| 	Selects      []string // selected columns
 | ||||
| 	Omits        []string // omit columns
 | ||||
| @ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { | ||||
| // BuildCondtion build condition
 | ||||
| func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { | ||||
| 	if sql, ok := query.(string); ok { | ||||
| 		if i, err := strconv.Atoi(sql); err != nil { | ||||
| 		if i, err := strconv.Atoi(sql); err == nil { | ||||
| 			query = i | ||||
| 		} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { | ||||
| 			return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} | ||||
| @ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) { | ||||
| } | ||||
| 
 | ||||
| func (stmt *Statement) Parse(value interface{}) (err error) { | ||||
| 	if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { | ||||
| 	if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { | ||||
| 		stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue) | ||||
| 
 | ||||
| 		if stmt.Table == "" { | ||||
| 			stmt.Table = stmt.Schema.Table | ||||
| 		} | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func TestCreate(t *testing.T, db *gorm.DB) { | ||||
| 	db.AutoMigrate(&User{}) | ||||
| 	db = db.Debug() | ||||
| 
 | ||||
| 	t.Run("Create", func(t *testing.T) { | ||||
| 		var user = User{ | ||||
| 			Name:     "create", | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu