Explain SQL for dialects
This commit is contained in:
		
							parent
							
								
									fe24c3f105
								
							
						
					
					
						commit
						bc5ceff82f
					
				| @ -3,6 +3,7 @@ package gorm | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| @ -69,6 +70,7 @@ func (cs *callbacks) Raw() *processor { | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Execute(db *DB) { | ||||
| 	curTime := time.Now() | ||||
| 	if stmt := db.Statement; stmt != nil { | ||||
| 		if stmt.Model == nil { | ||||
| 			stmt.Model = stmt.Dest | ||||
| @ -86,6 +88,12 @@ func (p *processor) Execute(db *DB) { | ||||
| 	for _, f := range p.fns { | ||||
| 		f(db) | ||||
| 	} | ||||
| 
 | ||||
| 	if stmt := db.Statement; stmt != nil { | ||||
| 		db.Logger.RunWith(logger.Info, func() { | ||||
| 			db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Get(name string) func(*DB) { | ||||
|  | ||||
| @ -3,11 +3,13 @@ package mssql | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 
 | ||||
| 	_ "github.com/denisenkom/go-mssqldb" | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/callbacks" | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/migrator" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| ) | ||||
| @ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte { | ||||
| 	return [2]byte{'"', '"'} // `name`
 | ||||
| } | ||||
| 
 | ||||
| var numericPlaceholder = regexp.MustCompile("@p(\\d+)") | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
|  | ||||
| @ -2,6 +2,7 @@ package mssql_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| @ -15,7 +16,12 @@ var ( | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
| 	if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil { | ||||
| 	dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" | ||||
| 	if os.Getenv("GORM_DSN") != "" { | ||||
| 		dsn = os.Getenv("GORM_DSN") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil { | ||||
| 		panic(fmt.Sprintf("failed to initialize database, got error %v", err)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -8,6 +8,7 @@ import ( | ||||
| 	_ "github.com/go-sql-driver/mysql" | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/callbacks" | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/migrator" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| ) | ||||
| @ -42,6 +43,10 @@ func (dialector Dialector) QuoteChars() [2]byte { | ||||
| 	return [2]byte{'`', '`'} // `name`
 | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, nil, `"`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
|  | ||||
| @ -2,6 +2,7 @@ package mysql_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| @ -15,7 +16,12 @@ var ( | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
| 	if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil { | ||||
| 	dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" | ||||
| 	if os.Getenv("GORM_DSN") != "" { | ||||
| 		dsn = os.Getenv("GORM_DSN") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil { | ||||
| 		panic(fmt.Sprintf("failed to initialize database, got error %v", err)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -3,10 +3,12 @@ package postgres | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/callbacks" | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/migrator" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| 	_ "github.com/lib/pq" | ||||
| @ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte { | ||||
| 	return [2]byte{'"', '"'} // "name"
 | ||||
| } | ||||
| 
 | ||||
| var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
|  | ||||
| @ -2,6 +2,7 @@ package postgres_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| @ -15,7 +16,12 @@ var ( | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
| 	if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil { | ||||
| 	dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" | ||||
| 	if os.Getenv("GORM_DSN") != "" { | ||||
| 		dsn = os.Getenv("GORM_DSN") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil { | ||||
| 		panic(fmt.Sprintf("failed to initialize database, got error %v", err)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/callbacks" | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/migrator" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| @ -41,6 +42,10 @@ func (dialector Dialector) QuoteChars() [2]byte { | ||||
| 	return [2]byte{'`', '`'} // `name`
 | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string { | ||||
| 	return logger.ExplainSQL(sql, nil, `"`, vars...) | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
|  | ||||
| @ -14,6 +14,7 @@ type Dialector interface { | ||||
| 	DataTypeOf(*schema.Field) string | ||||
| 	BindVar(stmt *Statement, v interface{}) string | ||||
| 	QuoteChars() [2]byte | ||||
| 	Explain(sql string, vars ...interface{}) string | ||||
| } | ||||
| 
 | ||||
| // CommonDB common db interface
 | ||||
|  | ||||
| @ -11,9 +11,9 @@ type LogLevel int | ||||
| var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)} | ||||
| 
 | ||||
| const ( | ||||
| 	Info LogLevel = iota + 1 | ||||
| 	Error LogLevel = iota + 1 | ||||
| 	Warn | ||||
| 	Error | ||||
| 	Info | ||||
| ) | ||||
| 
 | ||||
| // Interface logger interface
 | ||||
| @ -22,6 +22,7 @@ type Interface interface { | ||||
| 	Info(string, ...interface{}) | ||||
| 	Warn(string, ...interface{}) | ||||
| 	Error(string, ...interface{}) | ||||
| 	RunWith(LogLevel, func()) | ||||
| } | ||||
| 
 | ||||
| // Writer log writer interface
 | ||||
| @ -40,21 +41,27 @@ func (logger Logger) LogMode(level LogLevel) Interface { | ||||
| 
 | ||||
| // Info print info
 | ||||
| func (logger Logger) Info(msg string, data ...interface{}) { | ||||
| 	if logger.logLevel <= Info { | ||||
| 	if logger.logLevel >= Info { | ||||
| 		logger.Print("[info] " + fmt.Sprintf(msg, data...)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Warn print warn messages
 | ||||
| func (logger Logger) Warn(msg string, data ...interface{}) { | ||||
| 	if logger.logLevel <= Warn { | ||||
| 	if logger.logLevel >= Warn { | ||||
| 		logger.Print("[warn] " + fmt.Sprintf(msg, data...)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Error print error messages
 | ||||
| func (logger Logger) Error(msg string, data ...interface{}) { | ||||
| 	if logger.logLevel <= Error { | ||||
| 	if logger.logLevel >= Error { | ||||
| 		logger.Print("[error] " + fmt.Sprintf(msg, data...)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (logger Logger) RunWith(logLevel LogLevel, fc func()) { | ||||
| 	if logger.logLevel >= logLevel { | ||||
| 		fc() | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										68
									
								
								logger/sql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								logger/sql.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,68 @@ | ||||
| package logger | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode" | ||||
| ) | ||||
| 
 | ||||
| func isPrintable(s []byte) bool { | ||||
| 	for _, r := range s { | ||||
| 		if !unicode.IsPrint(rune(r)) { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { | ||||
| 	for idx, v := range vars { | ||||
| 		if valuer, ok := v.(driver.Valuer); ok { | ||||
| 			v, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		switch v := v.(type) { | ||||
| 		case bool: | ||||
| 			vars[idx] = fmt.Sprint(v) | ||||
| 		case time.Time: | ||||
| 			vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper | ||||
| 		case *time.Time: | ||||
| 			vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper | ||||
| 		case []byte: | ||||
| 			if isPrintable(v) { | ||||
| 				vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper | ||||
| 			} else { | ||||
| 				vars[idx] = escaper + "<binary>" + escaper | ||||
| 			} | ||||
| 		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: | ||||
| 			vars[idx] = fmt.Sprintf("%d", v) | ||||
| 		case float64, float32: | ||||
| 			vars[idx] = fmt.Sprintf("%.6f", v) | ||||
| 		case string: | ||||
| 			vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper | ||||
| 		default: | ||||
| 			if v == nil { | ||||
| 				vars[idx] = "NULL" | ||||
| 			} else { | ||||
| 				vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if numericPlaceholder == nil { | ||||
| 		for _, v := range vars { | ||||
| 			sql = strings.Replace(sql, "?", v.(string), 1) | ||||
| 		} | ||||
| 	} else { | ||||
| 		sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1") | ||||
| 		for idx, v := range vars { | ||||
| 			sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return sql | ||||
| } | ||||
							
								
								
									
										45
									
								
								logger/sql_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								logger/sql_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,45 @@ | ||||
| package logger_test | ||||
| 
 | ||||
| import ( | ||||
| 	"regexp" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/now" | ||||
| ) | ||||
| 
 | ||||
| func TestExplainSQL(t *testing.T) { | ||||
| 	tt := now.MustParse("2020-02-23 11:10:10") | ||||
| 
 | ||||
| 	results := []struct { | ||||
| 		SQL           string | ||||
| 		NumericRegexp *regexp.Regexp | ||||
| 		Vars          []interface{} | ||||
| 		Result        string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)", | ||||
| 			NumericRegexp: regexp.MustCompile("@p(\\d+)"), | ||||
| 			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)", | ||||
| 			NumericRegexp: regexp.MustCompile("\\$(\\d+)"), | ||||
| 			Vars:          []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for idx, r := range results { | ||||
| 		if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result { | ||||
| 			t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -1,4 +1,4 @@ | ||||
| dialects=("postgres" "mysql" "mssql" "sqlite") | ||||
| dialects=("sqlite" "mysql" "postgres" "mssql") | ||||
| 
 | ||||
| if [[ $(pwd) == *"gorm/tests"* ]]; then | ||||
|   cd .. | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu