Better support Stringer when explain SQL
This commit is contained in:
		
							parent
							
								
									d22215129e
								
							
						
					
					
						commit
						4eeb839cea
					
				| @ -12,6 +12,7 @@ import ( | |||||||
| 	"gorm.io/gorm/utils" | 	"gorm.io/gorm/utils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // ErrRecordNotFound record not found error
 | ||||||
| var ErrRecordNotFound = errors.New("record not found") | var ErrRecordNotFound = errors.New("record not found") | ||||||
| 
 | 
 | ||||||
| // Colors
 | // Colors
 | ||||||
| @ -30,13 +31,17 @@ const ( | |||||||
| 	YellowBold  = "\033[33;1m" | 	YellowBold  = "\033[33;1m" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // LogLevel
 | // LogLevel log level
 | ||||||
| type LogLevel int | type LogLevel int | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | 	// Silent silent log level
 | ||||||
| 	Silent LogLevel = iota + 1 | 	Silent LogLevel = iota + 1 | ||||||
|  | 	// Error error log level
 | ||||||
| 	Error | 	Error | ||||||
|  | 	// Warn warn log level
 | ||||||
| 	Warn | 	Warn | ||||||
|  | 	// Info info log level
 | ||||||
| 	Info | 	Info | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -45,6 +50,7 @@ type Writer interface { | |||||||
| 	Printf(string, ...interface{}) | 	Printf(string, ...interface{}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Config logger config
 | ||||||
| type Config struct { | type Config struct { | ||||||
| 	SlowThreshold             time.Duration | 	SlowThreshold             time.Duration | ||||||
| 	Colorful                  bool | 	Colorful                  bool | ||||||
| @ -62,16 +68,20 @@ type Interface interface { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
|  | 	// Discard Discard logger will print any log to ioutil.Discard
 | ||||||
| 	Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) | 	Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) | ||||||
|  | 	// Default Default logger
 | ||||||
| 	Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ | 	Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ | ||||||
| 		SlowThreshold:             200 * time.Millisecond, | 		SlowThreshold:             200 * time.Millisecond, | ||||||
| 		LogLevel:                  Warn, | 		LogLevel:                  Warn, | ||||||
| 		IgnoreRecordNotFoundError: false, | 		IgnoreRecordNotFoundError: false, | ||||||
| 		Colorful:                  true, | 		Colorful:                  true, | ||||||
| 	}) | 	}) | ||||||
|  | 	// Recorder Recorder logger records running SQL into a recorder instance
 | ||||||
| 	Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} | 	Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // New initialize logger
 | ||||||
| func New(writer Writer, config Config) Interface { | func New(writer Writer, config Config) Interface { | ||||||
| 	var ( | 	var ( | ||||||
| 		infoStr      = "%s\n[info] " | 		infoStr      = "%s\n[info] " | ||||||
| @ -179,10 +189,12 @@ type traceRecorder struct { | |||||||
| 	Err          error | 	Err          error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // New new trace recorder
 | ||||||
| func (l traceRecorder) New() *traceRecorder { | func (l traceRecorder) New() *traceRecorder { | ||||||
| 	return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} | 	return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Trace implement logger interface
 | ||||||
| func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { | func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { | ||||||
| 	l.BeginAt = begin | 	l.BeginAt = begin | ||||||
| 	l.SQL, l.RowsAffected = fc() | 	l.SQL, l.RowsAffected = fc() | ||||||
|  | |||||||
| @ -30,9 +30,12 @@ func isPrintable(s []byte) bool { | |||||||
| 
 | 
 | ||||||
| var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} | var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} | ||||||
| 
 | 
 | ||||||
|  | // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 | ||||||
| func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { | func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { | ||||||
| 	var convertParams func(interface{}, int) | 	var ( | ||||||
| 	vars := make([]string, len(avars)) | 		convertParams func(interface{}, int) | ||||||
|  | 		vars          = make([]string, len(avars)) | ||||||
|  | 	) | ||||||
| 
 | 
 | ||||||
| 	convertParams = func(v interface{}, idx int) { | 	convertParams = func(v interface{}, idx int) { | ||||||
| 		switch v := v.(type) { | 		switch v := v.(type) { | ||||||
| @ -64,10 +67,21 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a | |||||||
| 			} | 			} | ||||||
| 		case fmt.Stringer: | 		case fmt.Stringer: | ||||||
| 			reflectValue := reflect.ValueOf(v) | 			reflectValue := reflect.ValueOf(v) | ||||||
| 			if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { | 			switch reflectValue.Kind() { | ||||||
|  | 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||||
|  | 				vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) | ||||||
|  | 			case reflect.Float32, reflect.Float64: | ||||||
|  | 				vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) | ||||||
|  | 			case reflect.Bool: | ||||||
|  | 				vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) | ||||||
|  | 			case reflect.String: | ||||||
| 				vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper | 				vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper | ||||||
| 			} else { | 			default: | ||||||
| 				vars[idx] = nullStr | 				if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { | ||||||
|  | 					vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper | ||||||
|  | 				} else { | ||||||
|  | 					vars[idx] = nullStr | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		case []byte: | 		case []byte: | ||||||
| 			if isPrintable(v) { | 			if isPrintable(v) { | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ require ( | |||||||
| 	github.com/jinzhu/now v1.1.4 | 	github.com/jinzhu/now v1.1.4 | ||||||
| 	github.com/lib/pq v1.10.4 | 	github.com/lib/pq v1.10.4 | ||||||
| 	github.com/mattn/go-sqlite3 v1.14.11 // indirect | 	github.com/mattn/go-sqlite3 v1.14.11 // indirect | ||||||
| 	golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect | 	golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect | ||||||
| 	gorm.io/driver/mysql v1.2.3 | 	gorm.io/driver/mysql v1.2.3 | ||||||
| 	gorm.io/driver/postgres v1.2.3 | 	gorm.io/driver/postgres v1.2.3 | ||||||
| 	gorm.io/driver/sqlite v1.2.6 | 	gorm.io/driver/sqlite v1.2.6 | ||||||
|  | |||||||
| @ -168,6 +168,59 @@ func TestDryRun(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type ageInt int8 | ||||||
|  | 
 | ||||||
|  | func (ageInt) String() string { | ||||||
|  | 	return "age" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ageBool bool | ||||||
|  | 
 | ||||||
|  | func (ageBool) String() string { | ||||||
|  | 	return "age" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ageUint64 uint64 | ||||||
|  | 
 | ||||||
|  | func (ageUint64) String() string { | ||||||
|  | 	return "age" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ageFloat float64 | ||||||
|  | 
 | ||||||
|  | func (ageFloat) String() string { | ||||||
|  | 	return "age" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestExplainSQL(t *testing.T) { | ||||||
|  | 	user := *GetUser("explain-sql", Config{}) | ||||||
|  | 	dryRunDB := DB.Session(&gorm.Session{DryRun: true}) | ||||||
|  | 
 | ||||||
|  | 	stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement | ||||||
|  | 	sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||||
|  | 	if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { | ||||||
|  | 		t.Errorf("Failed to generate sql, got %v", sql) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement | ||||||
|  | 	sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||||
|  | 	if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { | ||||||
|  | 		t.Errorf("Failed to generate sql, got %v", sql) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement | ||||||
|  | 	sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||||
|  | 	if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { | ||||||
|  | 		t.Errorf("Failed to generate sql, got %v", sql) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement | ||||||
|  | 	sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||||
|  | 	if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { | ||||||
|  | 		t.Errorf("Failed to generate sql, got %v", sql) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestGroupConditions(t *testing.T) { | func TestGroupConditions(t *testing.T) { | ||||||
| 	type Pizza struct { | 	type Pizza struct { | ||||||
| 		ID   uint | 		ID   uint | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu