Work on clauses
This commit is contained in:
		
							parent
							
								
									8cb15cadde
								
							
						
					
					
						commit
						d833efe8b9
					
				
							
								
								
									
										13
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								callbacks.go
									
									
									
									
									
								
							| @ -1,9 +1,11 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| 	"github.com/jinzhu/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| @ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor { | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Execute(db *DB) { | ||||
| 	if stmt := db.Statement; stmt != nil && stmt.Dest != nil { | ||||
| 		var err error | ||||
| 		stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy) | ||||
| 
 | ||||
| 		if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) { | ||||
| 			db.AddError(err) | ||||
| 		} else if stmt.Table == "" && stmt.Schema != nil { | ||||
| 			stmt.Table = stmt.Schema.Table | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, f := range p.fns { | ||||
| 		f(db) | ||||
| 	} | ||||
|  | ||||
| @ -1,6 +1,10 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import "github.com/jinzhu/gorm" | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func BeforeCreate(db *gorm.DB) { | ||||
| 	// before save
 | ||||
| @ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Create(db *gorm.DB) { | ||||
| 	db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") | ||||
| 
 | ||||
| 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | ||||
| } | ||||
| 
 | ||||
| func SaveAfterAssociations(db *gorm.DB) { | ||||
| @ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) { | ||||
| 	// after save
 | ||||
| 	// after create
 | ||||
| } | ||||
| 
 | ||||
| func objectToFieldsMap(stmt *gorm.Statement) { | ||||
| 	if stmt.Schema != nil { | ||||
| 		if s, ok := stmt.Clauses["SELECT"]; ok { | ||||
| 			s.Attrs | ||||
| 		} | ||||
| 
 | ||||
| 		if s, ok := stmt.Clauses["OMIT"]; ok { | ||||
| 			s.Attrs | ||||
| 		} | ||||
| 
 | ||||
| 		stmt.Schema.LookUpField(s.S) | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										13
									
								
								callbacks/query.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								callbacks/query.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,13 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import "github.com/jinzhu/gorm" | ||||
| 
 | ||||
| func Query(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Preload(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func AfterQuery(db *gorm.DB) { | ||||
| 	// after find
 | ||||
| } | ||||
							
								
								
									
										129
									
								
								clause/clause.go
									
									
									
									
									
								
							
							
						
						
									
										129
									
								
								clause/clause.go
									
									
									
									
									
								
							| @ -51,124 +51,21 @@ type OverrideNameInterface interface { | ||||
| 	OverrideName() string | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| // Predefined Clauses
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| 
 | ||||
| // Where where clause
 | ||||
| type Where struct { | ||||
| 	AndConditions AddConditions | ||||
| 	ORConditions  []ORConditions | ||||
| 	builders      []Expression | ||||
| // Column quote with name
 | ||||
| type Column struct { | ||||
| 	Table string | ||||
| 	Name  string | ||||
| 	Alias string | ||||
| 	Raw   bool | ||||
| } | ||||
| 
 | ||||
| func (where Where) Name() string { | ||||
| 	return "WHERE" | ||||
| func ToColumns(value ...interface{}) []Column { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (where Where) Build(builder Builder) { | ||||
| 	var withConditions bool | ||||
| 
 | ||||
| 	if len(where.AndConditions) > 0 { | ||||
| 		withConditions = true | ||||
| 		where.AndConditions.Build(builder) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(where.builders) > 0 { | ||||
| 		for _, b := range where.builders { | ||||
| 			if withConditions { | ||||
| 				builder.Write(" AND ") | ||||
| 			} | ||||
| 			withConditions = true | ||||
| 			b.Build(builder) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	var singleOrConditions []ORConditions | ||||
| 	for _, or := range where.ORConditions { | ||||
| 		if len(or) == 1 { | ||||
| 			if withConditions { | ||||
| 				builder.Write(" OR ") | ||||
| 				or.Build(builder) | ||||
| 			} else { | ||||
| 				singleOrConditions = append(singleOrConditions, or) | ||||
| 			} | ||||
| 		} else { | ||||
| 			withConditions = true | ||||
| 			builder.Write(" AND (") | ||||
| 			or.Build(builder) | ||||
| 			builder.WriteByte(')') | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, or := range singleOrConditions { | ||||
| 		if withConditions { | ||||
| 			builder.Write(" AND ") | ||||
| 			or.Build(builder) | ||||
| 		} else { | ||||
| 			withConditions = true | ||||
| 			or.Build(builder) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !withConditions { | ||||
| 		builder.Write(" FALSE") | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (where Where) MergeExpression(expr Expression) { | ||||
| 	if w, ok := expr.(Where); ok { | ||||
| 		where.AndConditions = append(where.AndConditions, w.AndConditions...) | ||||
| 		where.ORConditions = append(where.ORConditions, w.ORConditions...) | ||||
| 		where.builders = append(where.builders, w.builders...) | ||||
| 	} else { | ||||
| 		where.builders = append(where.builders, expr) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Select select attrs when querying, updating, creating
 | ||||
| type Select struct { | ||||
| 	Omit bool | ||||
| } | ||||
| 
 | ||||
| // Join join clause
 | ||||
| type Join struct { | ||||
| 	Table    string | ||||
| 	Type     string // left join books on
 | ||||
| 	ON       []Expression | ||||
| 	builders []Expression | ||||
| } | ||||
| 
 | ||||
| func (join Join) Build(builder Builder) { | ||||
| 	// TODO
 | ||||
| } | ||||
| 
 | ||||
| func (join Join) MergeExpression(expr Expression) { | ||||
| 	if j, ok := expr.(Join); ok { | ||||
| 		join.builders = append(join.builders, j.builders...) | ||||
| 	} else { | ||||
| 		join.builders = append(join.builders, expr) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // GroupBy group by clause
 | ||||
| type GroupBy struct { | ||||
| } | ||||
| 
 | ||||
| // Having having clause
 | ||||
| type Having struct { | ||||
| } | ||||
| 
 | ||||
| // Order order clause
 | ||||
| type Order struct { | ||||
| } | ||||
| 
 | ||||
| // Limit limit clause
 | ||||
| type Limit struct { | ||||
| } | ||||
| 
 | ||||
| // Offset offset clause
 | ||||
| type Offset struct { | ||||
| // Table quote with name
 | ||||
| type Table struct { | ||||
| 	Table string | ||||
| 	Alias string | ||||
| 	Raw   bool | ||||
| } | ||||
|  | ||||
							
								
								
									
										22
									
								
								clause/from.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								clause/from.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,22 @@ | ||||
| package clause | ||||
| 
 | ||||
| // From from clause
 | ||||
| type From struct { | ||||
| 	Tables []Table | ||||
| } | ||||
| 
 | ||||
| // Name from clause name
 | ||||
| func (From) Name() string { | ||||
| 	return "FROM" | ||||
| } | ||||
| 
 | ||||
| // Build build from clause
 | ||||
| func (from From) Build(builder Builder) { | ||||
| 	for idx, table := range from.Tables { | ||||
| 		if idx > 0 { | ||||
| 			builder.WriteByte(',') | ||||
| 		} | ||||
| 
 | ||||
| 		builder.WriteQuoted(table) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										6
									
								
								clause/group_by.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								clause/group_by.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | ||||
| package clause | ||||
| 
 | ||||
| // GroupBy group by clause
 | ||||
| type GroupBy struct { | ||||
| 	Having Where | ||||
| } | ||||
							
								
								
									
										23
									
								
								clause/join.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								clause/join.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | ||||
| package clause | ||||
| 
 | ||||
| // Join join clause
 | ||||
| type Join struct { | ||||
| 	Table From   // From
 | ||||
| 	Type  string // INNER, LEFT, RIGHT, FULL, CROSS JOIN
 | ||||
| 	Using []Column | ||||
| 	ON    Where | ||||
| } | ||||
| 
 | ||||
| // TODO multiple joins
 | ||||
| 
 | ||||
| func (join Join) Build(builder Builder) { | ||||
| 	// TODO
 | ||||
| } | ||||
| 
 | ||||
| func (join Join) MergeExpression(expr Expression) { | ||||
| 	// if j, ok := expr.(Join); ok {
 | ||||
| 	// 	join.builders = append(join.builders, j.builders...)
 | ||||
| 	// } else {
 | ||||
| 	// 	join.builders = append(join.builders, expr)
 | ||||
| 	// }
 | ||||
| } | ||||
							
								
								
									
										6
									
								
								clause/limit.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								clause/limit.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | ||||
| package clause | ||||
| 
 | ||||
| // Limit limit clause
 | ||||
| type Limit struct { | ||||
| 	Offset uint | ||||
| } | ||||
							
								
								
									
										4
									
								
								clause/order_by.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								clause/order_by.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,4 @@ | ||||
| package clause | ||||
| 
 | ||||
| type OrderBy struct { | ||||
| } | ||||
| @ -2,12 +2,6 @@ package clause | ||||
| 
 | ||||
| import "strings" | ||||
| 
 | ||||
| // Column quote with name
 | ||||
| type Column struct { | ||||
| 	Table string | ||||
| 	Name  string | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| // Query Expressions
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
|  | ||||
							
								
								
									
										45
									
								
								clause/select.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								clause/select.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,45 @@ | ||||
| package clause | ||||
| 
 | ||||
| // Select select attrs when querying, updating, creating
 | ||||
| type Select struct { | ||||
| 	SelectColumns []Column | ||||
| 	OmitColumns   []Column | ||||
| } | ||||
| 
 | ||||
| // SelectInterface select clause interface
 | ||||
| type SelectInterface interface { | ||||
| 	Selects() []Column | ||||
| 	Omits() []Column | ||||
| } | ||||
| 
 | ||||
| func (s Select) Selects() []Column { | ||||
| 	return s.SelectColumns | ||||
| } | ||||
| 
 | ||||
| func (s Select) Omits() []Column { | ||||
| 	return s.OmitColumns | ||||
| } | ||||
| 
 | ||||
| func (s Select) Build(builder Builder) { | ||||
| 	if len(s.SelectColumns) > 0 { | ||||
| 		for idx, column := range s.SelectColumns { | ||||
| 			if idx > 0 { | ||||
| 				builder.WriteByte(',') | ||||
| 			} | ||||
| 			builder.WriteQuoted(column) | ||||
| 		} | ||||
| 	} else { | ||||
| 		builder.WriteByte('*') | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s Select) MergeExpression(expr Expression) { | ||||
| 	if v, ok := expr.(SelectInterface); ok { | ||||
| 		if len(s.SelectColumns) == 0 { | ||||
| 			s.SelectColumns = v.Selects() | ||||
| 		} | ||||
| 		if len(s.OmitColumns) == 0 { | ||||
| 			s.OmitColumns = v.Omits() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										77
									
								
								clause/where.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								clause/where.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,77 @@ | ||||
| package clause | ||||
| 
 | ||||
| // Where where clause
 | ||||
| type Where struct { | ||||
| 	AndConditions AddConditions | ||||
| 	ORConditions  []ORConditions | ||||
| 	builders      []Expression | ||||
| } | ||||
| 
 | ||||
| // Name where clause name
 | ||||
| func (where Where) Name() string { | ||||
| 	return "WHERE" | ||||
| } | ||||
| 
 | ||||
| // Build build where clause
 | ||||
| func (where Where) Build(builder Builder) { | ||||
| 	var withConditions bool | ||||
| 
 | ||||
| 	if len(where.AndConditions) > 0 { | ||||
| 		withConditions = true | ||||
| 		where.AndConditions.Build(builder) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(where.builders) > 0 { | ||||
| 		for _, b := range where.builders { | ||||
| 			if withConditions { | ||||
| 				builder.Write(" AND ") | ||||
| 			} | ||||
| 			withConditions = true | ||||
| 			b.Build(builder) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	var singleOrConditions []ORConditions | ||||
| 	for _, or := range where.ORConditions { | ||||
| 		if len(or) == 1 { | ||||
| 			if withConditions { | ||||
| 				builder.Write(" OR ") | ||||
| 				or.Build(builder) | ||||
| 			} else { | ||||
| 				singleOrConditions = append(singleOrConditions, or) | ||||
| 			} | ||||
| 		} else { | ||||
| 			withConditions = true | ||||
| 			builder.Write(" AND (") | ||||
| 			or.Build(builder) | ||||
| 			builder.WriteByte(')') | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, or := range singleOrConditions { | ||||
| 		if withConditions { | ||||
| 			builder.Write(" AND ") | ||||
| 			or.Build(builder) | ||||
| 		} else { | ||||
| 			withConditions = true | ||||
| 			or.Build(builder) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !withConditions { | ||||
| 		builder.Write(" FALSE") | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // MergeExpression merge where clauses
 | ||||
| func (where Where) MergeExpression(expr Expression) { | ||||
| 	if w, ok := expr.(Where); ok { | ||||
| 		where.AndConditions = append(where.AndConditions, w.AndConditions...) | ||||
| 		where.ORConditions = append(where.ORConditions, w.ORConditions...) | ||||
| 		where.builders = append(where.builders, w.builders...) | ||||
| 	} else { | ||||
| 		where.builders = append(where.builders, expr) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										4
									
								
								clause/with.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								clause/with.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,4 @@ | ||||
| package clause | ||||
| 
 | ||||
| type With struct { | ||||
| } | ||||
| @ -1,7 +1,5 @@ | ||||
| module github.com/jinzhu/gorm/dialects/mysql | ||||
| module github.com/jinzhu/gorm/dialects/sqlite | ||||
| 
 | ||||
| go 1.13 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible | ||||
| ) | ||||
| require github.com/mattn/go-sqlite3 v2.0.3+incompatible | ||||
|  | ||||
							
								
								
									
										2
									
								
								dialects/sqlite/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								dialects/sqlite/go.sum
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | ||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= | ||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= | ||||
| @ -1,6 +1,7 @@ | ||||
| package sqlite | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/callbacks" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| ) | ||||
|  | ||||
| @ -1,15 +1,27 @@ | ||||
| package sqlite_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/dialects/sqlite" | ||||
| 	"github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| var DB *gorm.DB | ||||
| var ( | ||||
| 	DB  *gorm.DB | ||||
| 	err error | ||||
| ) | ||||
| 
 | ||||
| func TestOpen(t *testing.T) { | ||||
| 	db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) | ||||
| func init() { | ||||
| 	if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil { | ||||
| 		panic(fmt.Sprintf("failed to initialize database, got error %v", err)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSqlite(t *testing.T) { | ||||
| 	tests.RunTestsSuit(t, DB) | ||||
| } | ||||
|  | ||||
| @ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { | ||||
| // First find first record that match given conditions, order by primary key
 | ||||
| func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.callbacks.Create().Execute(tx.Limit(1).Order("id")) | ||||
| 	tx.Statement.Dest = out | ||||
| 	tx.Limit(1) | ||||
| 	tx.callbacks.Query().Execute(tx) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Row() *sql.Row { | ||||
| 	// TODO
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Rows() (*sql.Rows, error) { | ||||
| 	// TODO
 | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							| @ -2,4 +2,7 @@ module github.com/jinzhu/gorm | ||||
| 
 | ||||
| go 1.13 | ||||
| 
 | ||||
| require github.com/jinzhu/inflection v1.0.0 | ||||
| require ( | ||||
| 	github.com/jinzhu/inflection v1.0.0 | ||||
| 	gopkg.in/errgo.v2 v2.1.0 | ||||
| ) | ||||
|  | ||||
							
								
								
									
										43
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								gorm.go
									
									
									
									
									
								
							| @ -2,6 +2,7 @@ package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| @ -12,36 +13,28 @@ import ( | ||||
| // Config GORM config
 | ||||
| type Config struct { | ||||
| 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | ||||
| 	// You can cancel it by setting `SkipDefaultTransaction` to true
 | ||||
| 	SkipDefaultTransaction bool // TODO
 | ||||
| 
 | ||||
| 	// You can disable it by setting `SkipDefaultTransaction` to true
 | ||||
| 	SkipDefaultTransaction bool | ||||
| 	// NamingStrategy tables, columns naming strategy
 | ||||
| 	NamingStrategy schema.Namer | ||||
| 
 | ||||
| 	// Logger
 | ||||
| 	Logger logger.Interface | ||||
| 
 | ||||
| 	// NowFunc the function to be used when creating a new timestamp
 | ||||
| 	NowFunc func() time.Time | ||||
| } | ||||
| 
 | ||||
| // Dialector GORM database dialector
 | ||||
| type Dialector interface { | ||||
| 	Initialize(*DB) error | ||||
| 	Migrator() Migrator | ||||
| 	BindVar(stmt Statement, v interface{}) string | ||||
| } | ||||
| 
 | ||||
| // DB GORM DB definition
 | ||||
| type DB struct { | ||||
| 	*Config | ||||
| 	Dialector | ||||
| 	Instance | ||||
| 	clone     bool | ||||
| 	callbacks *callbacks | ||||
| 	DB         CommonDB | ||||
| 	clone      bool | ||||
| 	callbacks  *callbacks | ||||
| 	cacheStore *sync.Map | ||||
| } | ||||
| 
 | ||||
| // Session session config when create new session
 | ||||
| // Session session config when create session with Session() method
 | ||||
| type Session struct { | ||||
| 	Context context.Context | ||||
| 	Logger  logger.Interface | ||||
| @ -67,10 +60,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | ||||
| 	} | ||||
| 
 | ||||
| 	db = &DB{ | ||||
| 		Config:    config, | ||||
| 		Dialector: dialector, | ||||
| 		clone:     true, | ||||
| 		callbacks: InitializeCallbacks(), | ||||
| 		Config:     config, | ||||
| 		Dialector:  dialector, | ||||
| 		clone:      true, | ||||
| 		callbacks:  InitializeCallbacks(), | ||||
| 		cacheStore: &sync.Map{}, | ||||
| 	} | ||||
| 
 | ||||
| 	if dialector != nil { | ||||
| @ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) { | ||||
| 	return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Close() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Set store value with key into current db instance's context
 | ||||
| func (db *DB) Set(key string, value interface{}) *DB { | ||||
| 	tx := db.getInstance() | ||||
| @ -145,12 +135,15 @@ func (db *DB) getInstance() *DB { | ||||
| 		} | ||||
| 
 | ||||
| 		return &DB{ | ||||
| 			Config:    db.Config, | ||||
| 			Dialector: db.Dialector, | ||||
| 			Instance: Instance{ | ||||
| 				Context:   ctx, | ||||
| 				Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, | ||||
| 			}, | ||||
| 			Config:     db.Config, | ||||
| 			Dialector:  db.Dialector, | ||||
| 			DB:         db.DB, | ||||
| 			callbacks:  db.callbacks, | ||||
| 			cacheStore: db.cacheStore, | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										21
									
								
								interfaces.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								interfaces.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| ) | ||||
| 
 | ||||
| // Dialector GORM database dialector
 | ||||
| type Dialector interface { | ||||
| 	Initialize(*DB) error | ||||
| 	Migrator() Migrator | ||||
| 	BindVar(stmt Statement, v interface{}) string | ||||
| } | ||||
| 
 | ||||
| // CommonDB common db interface
 | ||||
| type CommonDB interface { | ||||
| 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||
| 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||
| 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||
| } | ||||
| @ -1,6 +1,7 @@ | ||||
| package schema | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"go/ast" | ||||
| 	"reflect" | ||||
| @ -9,6 +10,9 @@ import ( | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| ) | ||||
| 
 | ||||
| // ErrUnsupportedDataType unsupported data type
 | ||||
| var ErrUnsupportedDataType = errors.New("unsupported data type") | ||||
| 
 | ||||
| type Schema struct { | ||||
| 	Name                    string | ||||
| 	ModelType               reflect.Type | ||||
| @ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | ||||
| 
 | ||||
| 	if modelType.Kind() != reflect.Struct { | ||||
| 		if modelType.PkgPath() == "" { | ||||
| 			return nil, fmt.Errorf("unsupported data %+v when parsing model", dest) | ||||
| 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||
| 		} | ||||
| 		return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath()) | ||||
| 		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | ||||
| 	} | ||||
| 
 | ||||
| 	if v, ok := cacheStore.Load(modelType); ok { | ||||
| @ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, field := range schema.Fields { | ||||
| 		if field.DBName == "" { | ||||
| 		if field.DBName == "" && field.DataType != "" { | ||||
| 			field.DBName = namer.ColumnName(schema.Table, field.Name) | ||||
| 		} | ||||
| 
 | ||||
|  | ||||
| @ -2,24 +2,16 @@ package schema_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| 	"github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { | ||||
| 	t.Run("CheckSchema/"+s.Name, func(t *testing.T) { | ||||
| 		equalFieldNames := []string{"Name", "Table"} | ||||
| 
 | ||||
| 		for _, name := range equalFieldNames { | ||||
| 			got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() | ||||
| 			expects := reflect.ValueOf(v).FieldByName(name).Interface() | ||||
| 			if !reflect.DeepEqual(got, expects) { | ||||
| 				t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) | ||||
| 			} | ||||
| 		} | ||||
| 		tests.AssertEqual(t, s, v, "Name", "Table") | ||||
| 
 | ||||
| 		for idx, field := range primaryFields { | ||||
| 			var found bool | ||||
| @ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* | ||||
| 		if parsedField, ok := s.FieldsByName[f.Name]; !ok { | ||||
| 			t.Errorf("schema %v failed to look up field with name %v", s, f.Name) | ||||
| 		} else { | ||||
| 			equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} | ||||
| 
 | ||||
| 			for _, name := range equalFieldNames { | ||||
| 				got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() | ||||
| 				expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() | ||||
| 				if !reflect.DeepEqual(got, expects) { | ||||
| 					t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) | ||||
| 				} | ||||
| 			} | ||||
| 			tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") | ||||
| 
 | ||||
| 			if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { | ||||
| 				t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) | ||||
|  | ||||
							
								
								
									
										31
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								statement.go
									
									
									
									
									
								
							| @ -10,6 +10,7 @@ import ( | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| // Instance db instance
 | ||||
| @ -37,6 +38,7 @@ type Statement struct { | ||||
| 	Clauses  map[string]clause.Clause | ||||
| 	Settings sync.Map | ||||
| 	DB       *DB | ||||
| 	Schema   *schema.Schema | ||||
| 
 | ||||
| 	// SQL Builder
 | ||||
| 	SQL       strings.Builder | ||||
| @ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) { | ||||
| } | ||||
| 
 | ||||
| // Quote returns quoted value
 | ||||
| func (stmt Statement) Quote(field interface{}) (str string) { | ||||
| 	// FIXME
 | ||||
| 	return fmt.Sprint(field) | ||||
| func (stmt Statement) Quote(field interface{}) string { | ||||
| 	var str strings.Builder | ||||
| 
 | ||||
| 	switch v := field.(type) { | ||||
| 	case clause.Table: | ||||
| 		str.WriteString(v.Table) | ||||
| 		if v.Alias != "" { | ||||
| 			str.WriteString(" AS ") | ||||
| 			str.WriteString(v.Alias) | ||||
| 		} | ||||
| 	case clause.Column: | ||||
| 		if v.Table != "" { | ||||
| 			str.WriteString(v.Table) | ||||
| 			str.WriteByte('.') | ||||
| 		} | ||||
| 
 | ||||
| 		str.WriteString(v.Name) | ||||
| 		if v.Alias != "" { | ||||
| 			str.WriteString(" AS ") | ||||
| 			str.WriteString(v.Alias) | ||||
| 		} | ||||
| 	default: | ||||
| 		fmt.Sprint(field) | ||||
| 	} | ||||
| 
 | ||||
| 	return str.String() | ||||
| } | ||||
| 
 | ||||
| // Write write string
 | ||||
|  | ||||
| @ -1 +0,0 @@ | ||||
| package tests | ||||
							
								
								
									
										42
									
								
								tests/tests.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								tests/tests.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,42 @@ | ||||
| package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func Now() *time.Time { | ||||
| 	now := time.Now() | ||||
| 	return &now | ||||
| } | ||||
| 
 | ||||
| func RunTestsSuit(t *testing.T, db *gorm.DB) { | ||||
| 	TestCreate(t, db) | ||||
| } | ||||
| 
 | ||||
| func TestCreate(t *testing.T, db *gorm.DB) { | ||||
| 	t.Run("Create", func(t *testing.T) { | ||||
| 		var user = User{ | ||||
| 			Name:     "create", | ||||
| 			Age:      18, | ||||
| 			Birthday: Now(), | ||||
| 		} | ||||
| 
 | ||||
| 		if err := db.Create(&user).Error; err != nil { | ||||
| 			t.Errorf("errors happened when create: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if user.ID == 0 { | ||||
| 			t.Errorf("user's primary key should has value after create, got : %v", user.ID) | ||||
| 		} | ||||
| 
 | ||||
| 		var newUser User | ||||
| 		if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { | ||||
| 			t.Errorf("errors happened when query: %v", err) | ||||
| 		} else { | ||||
| 			AssertEqual(t, newUser, user, "Name", "Age", "Birthday") | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										19
									
								
								tests/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								tests/utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func AssertEqual(t *testing.T, r, e interface{}, names ...string) { | ||||
| 	for _, name := range names { | ||||
| 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() | ||||
| 		expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||
| 
 | ||||
| 		if !reflect.DeepEqual(got, expects) { | ||||
| 			t.Run(name, func(t *testing.T) { | ||||
| 				t.Errorf("expects: %v, got %v", expects, got) | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu