Better support Stringer when explain SQL
This commit is contained in:
		
							parent
							
								
									d22215129e
								
							
						
					
					
						commit
						4eeb839cea
					
				| @ -12,6 +12,7 @@ import ( | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| // ErrRecordNotFound record not found error
 | ||||
| var ErrRecordNotFound = errors.New("record not found") | ||||
| 
 | ||||
| // Colors
 | ||||
| @ -30,13 +31,17 @@ const ( | ||||
| 	YellowBold  = "\033[33;1m" | ||||
| ) | ||||
| 
 | ||||
| // LogLevel
 | ||||
| // LogLevel log level
 | ||||
| type LogLevel int | ||||
| 
 | ||||
| const ( | ||||
| 	// Silent silent log level
 | ||||
| 	Silent LogLevel = iota + 1 | ||||
| 	// Error error log level
 | ||||
| 	Error | ||||
| 	// Warn warn log level
 | ||||
| 	Warn | ||||
| 	// Info info log level
 | ||||
| 	Info | ||||
| ) | ||||
| 
 | ||||
| @ -45,6 +50,7 @@ type Writer interface { | ||||
| 	Printf(string, ...interface{}) | ||||
| } | ||||
| 
 | ||||
| // Config logger config
 | ||||
| type Config struct { | ||||
| 	SlowThreshold             time.Duration | ||||
| 	Colorful                  bool | ||||
| @ -62,16 +68,20 @@ type Interface interface { | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	// Discard Discard logger will print any log to ioutil.Discard
 | ||||
| 	Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) | ||||
| 	// Default Default logger
 | ||||
| 	Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ | ||||
| 		SlowThreshold:             200 * time.Millisecond, | ||||
| 		LogLevel:                  Warn, | ||||
| 		IgnoreRecordNotFoundError: false, | ||||
| 		Colorful:                  true, | ||||
| 	}) | ||||
| 	// Recorder Recorder logger records running SQL into a recorder instance
 | ||||
| 	Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} | ||||
| ) | ||||
| 
 | ||||
| // New initialize logger
 | ||||
| func New(writer Writer, config Config) Interface { | ||||
| 	var ( | ||||
| 		infoStr      = "%s\n[info] " | ||||
| @ -179,10 +189,12 @@ type traceRecorder struct { | ||||
| 	Err          error | ||||
| } | ||||
| 
 | ||||
| // New new trace recorder
 | ||||
| func (l traceRecorder) New() *traceRecorder { | ||||
| 	return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} | ||||
| } | ||||
| 
 | ||||
| // Trace implement logger interface
 | ||||
| func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { | ||||
| 	l.BeginAt = begin | ||||
| 	l.SQL, l.RowsAffected = fc() | ||||
|  | ||||
| @ -30,9 +30,12 @@ func isPrintable(s []byte) bool { | ||||
| 
 | ||||
| var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} | ||||
| 
 | ||||
| // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 | ||||
| func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { | ||||
| 	var convertParams func(interface{}, int) | ||||
| 	vars := make([]string, len(avars)) | ||||
| 	var ( | ||||
| 		convertParams func(interface{}, int) | ||||
| 		vars          = make([]string, len(avars)) | ||||
| 	) | ||||
| 
 | ||||
| 	convertParams = func(v interface{}, idx int) { | ||||
| 		switch v := v.(type) { | ||||
| @ -64,11 +67,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a | ||||
| 			} | ||||
| 		case fmt.Stringer: | ||||
| 			reflectValue := reflect.ValueOf(v) | ||||
| 			switch reflectValue.Kind() { | ||||
| 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||
| 				vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) | ||||
| 			case reflect.Float32, reflect.Float64: | ||||
| 				vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) | ||||
| 			case reflect.Bool: | ||||
| 				vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) | ||||
| 			case reflect.String: | ||||
| 				vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper | ||||
| 			default: | ||||
| 				if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { | ||||
| 					vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper | ||||
| 				} else { | ||||
| 					vars[idx] = nullStr | ||||
| 				} | ||||
| 			} | ||||
| 		case []byte: | ||||
| 			if isPrintable(v) { | ||||
| 				vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper | ||||
|  | ||||
| @ -9,7 +9,7 @@ require ( | ||||
| 	github.com/jinzhu/now v1.1.4 | ||||
| 	github.com/lib/pq v1.10.4 | ||||
| 	github.com/mattn/go-sqlite3 v1.14.11 // indirect | ||||
| 	golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect | ||||
| 	golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect | ||||
| 	gorm.io/driver/mysql v1.2.3 | ||||
| 	gorm.io/driver/postgres v1.2.3 | ||||
| 	gorm.io/driver/sqlite v1.2.6 | ||||
|  | ||||
| @ -168,6 +168,59 @@ func TestDryRun(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ageInt int8 | ||||
| 
 | ||||
| func (ageInt) String() string { | ||||
| 	return "age" | ||||
| } | ||||
| 
 | ||||
| type ageBool bool | ||||
| 
 | ||||
| func (ageBool) String() string { | ||||
| 	return "age" | ||||
| } | ||||
| 
 | ||||
| type ageUint64 uint64 | ||||
| 
 | ||||
| func (ageUint64) String() string { | ||||
| 	return "age" | ||||
| } | ||||
| 
 | ||||
| type ageFloat float64 | ||||
| 
 | ||||
| func (ageFloat) String() string { | ||||
| 	return "age" | ||||
| } | ||||
| 
 | ||||
| func TestExplainSQL(t *testing.T) { | ||||
| 	user := *GetUser("explain-sql", Config{}) | ||||
| 	dryRunDB := DB.Session(&gorm.Session{DryRun: true}) | ||||
| 
 | ||||
| 	stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement | ||||
| 	sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||
| 	if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { | ||||
| 		t.Errorf("Failed to generate sql, got %v", sql) | ||||
| 	} | ||||
| 
 | ||||
| 	stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement | ||||
| 	sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||
| 	if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { | ||||
| 		t.Errorf("Failed to generate sql, got %v", sql) | ||||
| 	} | ||||
| 
 | ||||
| 	stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement | ||||
| 	sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||
| 	if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { | ||||
| 		t.Errorf("Failed to generate sql, got %v", sql) | ||||
| 	} | ||||
| 
 | ||||
| 	stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement | ||||
| 	sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||
| 	if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { | ||||
| 		t.Errorf("Failed to generate sql, got %v", sql) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGroupConditions(t *testing.T) { | ||||
| 	type Pizza struct { | ||||
| 		ID   uint | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu