Work on clauses
This commit is contained in:
		
							parent
							
								
									8cb15cadde
								
							
						
					
					
						commit
						d833efe8b9
					
				
							
								
								
									
										13
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								callbacks.go
									
									
									
									
									
								
							| @ -1,9 +1,11 @@ | |||||||
| package gorm | package gorm | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm/logger" | 	"github.com/jinzhu/gorm/logger" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| 	"github.com/jinzhu/gorm/utils" | 	"github.com/jinzhu/gorm/utils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *processor) Execute(db *DB) { | func (p *processor) Execute(db *DB) { | ||||||
|  | 	if stmt := db.Statement; stmt != nil && stmt.Dest != nil { | ||||||
|  | 		var err error | ||||||
|  | 		stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy) | ||||||
|  | 
 | ||||||
|  | 		if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) { | ||||||
|  | 			db.AddError(err) | ||||||
|  | 		} else if stmt.Table == "" && stmt.Schema != nil { | ||||||
|  | 			stmt.Table = stmt.Schema.Table | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	for _, f := range p.fns { | 	for _, f := range p.fns { | ||||||
| 		f(db) | 		f(db) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -1,6 +1,10 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import "github.com/jinzhu/gorm" | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| func BeforeCreate(db *gorm.DB) { | func BeforeCreate(db *gorm.DB) { | ||||||
| 	// before save
 | 	// before save
 | ||||||
| @ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Create(db *gorm.DB) { | func Create(db *gorm.DB) { | ||||||
|  | 	db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") | ||||||
|  | 
 | ||||||
|  | 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func SaveAfterAssociations(db *gorm.DB) { | func SaveAfterAssociations(db *gorm.DB) { | ||||||
| @ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) { | |||||||
| 	// after save
 | 	// after save
 | ||||||
| 	// after create
 | 	// after create
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func objectToFieldsMap(stmt *gorm.Statement) { | ||||||
|  | 	if stmt.Schema != nil { | ||||||
|  | 		if s, ok := stmt.Clauses["SELECT"]; ok { | ||||||
|  | 			s.Attrs | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if s, ok := stmt.Clauses["OMIT"]; ok { | ||||||
|  | 			s.Attrs | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		stmt.Schema.LookUpField(s.S) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
							
								
								
									
										13
									
								
								callbacks/query.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								callbacks/query.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,13 @@ | |||||||
|  | package callbacks | ||||||
|  | 
 | ||||||
|  | import "github.com/jinzhu/gorm" | ||||||
|  | 
 | ||||||
|  | func Query(db *gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Preload(db *gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func AfterQuery(db *gorm.DB) { | ||||||
|  | 	// after find
 | ||||||
|  | } | ||||||
							
								
								
									
										129
									
								
								clause/clause.go
									
									
									
									
									
								
							
							
						
						
									
										129
									
								
								clause/clause.go
									
									
									
									
									
								
							| @ -51,124 +51,21 @@ type OverrideNameInterface interface { | |||||||
| 	OverrideName() string | 	OverrideName() string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////////////
 | // Column quote with name
 | ||||||
| // Predefined Clauses
 | type Column struct { | ||||||
| ////////////////////////////////////////////////////////////////////////////////
 | 	Table string | ||||||
| 
 | 	Name  string | ||||||
| // Where where clause
 | 	Alias string | ||||||
| type Where struct { | 	Raw   bool | ||||||
| 	AndConditions AddConditions |  | ||||||
| 	ORConditions  []ORConditions |  | ||||||
| 	builders      []Expression |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (where Where) Name() string { | func ToColumns(value ...interface{}) []Column { | ||||||
| 	return "WHERE" | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (where Where) Build(builder Builder) { | // Table quote with name
 | ||||||
| 	var withConditions bool | type Table struct { | ||||||
| 
 | 	Table string | ||||||
| 	if len(where.AndConditions) > 0 { | 	Alias string | ||||||
| 		withConditions = true | 	Raw   bool | ||||||
| 		where.AndConditions.Build(builder) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if len(where.builders) > 0 { |  | ||||||
| 		for _, b := range where.builders { |  | ||||||
| 			if withConditions { |  | ||||||
| 				builder.Write(" AND ") |  | ||||||
| 			} |  | ||||||
| 			withConditions = true |  | ||||||
| 			b.Build(builder) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var singleOrConditions []ORConditions |  | ||||||
| 	for _, or := range where.ORConditions { |  | ||||||
| 		if len(or) == 1 { |  | ||||||
| 			if withConditions { |  | ||||||
| 				builder.Write(" OR ") |  | ||||||
| 				or.Build(builder) |  | ||||||
| 			} else { |  | ||||||
| 				singleOrConditions = append(singleOrConditions, or) |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			withConditions = true |  | ||||||
| 			builder.Write(" AND (") |  | ||||||
| 			or.Build(builder) |  | ||||||
| 			builder.WriteByte(')') |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, or := range singleOrConditions { |  | ||||||
| 		if withConditions { |  | ||||||
| 			builder.Write(" AND ") |  | ||||||
| 			or.Build(builder) |  | ||||||
| 		} else { |  | ||||||
| 			withConditions = true |  | ||||||
| 			or.Build(builder) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if !withConditions { |  | ||||||
| 		builder.Write(" FALSE") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (where Where) MergeExpression(expr Expression) { |  | ||||||
| 	if w, ok := expr.(Where); ok { |  | ||||||
| 		where.AndConditions = append(where.AndConditions, w.AndConditions...) |  | ||||||
| 		where.ORConditions = append(where.ORConditions, w.ORConditions...) |  | ||||||
| 		where.builders = append(where.builders, w.builders...) |  | ||||||
| 	} else { |  | ||||||
| 		where.builders = append(where.builders, expr) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Select select attrs when querying, updating, creating
 |  | ||||||
| type Select struct { |  | ||||||
| 	Omit bool |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Join join clause
 |  | ||||||
| type Join struct { |  | ||||||
| 	Table    string |  | ||||||
| 	Type     string // left join books on
 |  | ||||||
| 	ON       []Expression |  | ||||||
| 	builders []Expression |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (join Join) Build(builder Builder) { |  | ||||||
| 	// TODO
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (join Join) MergeExpression(expr Expression) { |  | ||||||
| 	if j, ok := expr.(Join); ok { |  | ||||||
| 		join.builders = append(join.builders, j.builders...) |  | ||||||
| 	} else { |  | ||||||
| 		join.builders = append(join.builders, expr) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // GroupBy group by clause
 |  | ||||||
| type GroupBy struct { |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Having having clause
 |  | ||||||
| type Having struct { |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Order order clause
 |  | ||||||
| type Order struct { |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Limit limit clause
 |  | ||||||
| type Limit struct { |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Offset offset clause
 |  | ||||||
| type Offset struct { |  | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										22
									
								
								clause/from.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								clause/from.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,22 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | // From from clause
 | ||||||
|  | type From struct { | ||||||
|  | 	Tables []Table | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Name from clause name
 | ||||||
|  | func (From) Name() string { | ||||||
|  | 	return "FROM" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Build build from clause
 | ||||||
|  | func (from From) Build(builder Builder) { | ||||||
|  | 	for idx, table := range from.Tables { | ||||||
|  | 		if idx > 0 { | ||||||
|  | 			builder.WriteByte(',') | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		builder.WriteQuoted(table) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										6
									
								
								clause/group_by.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								clause/group_by.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | // GroupBy group by clause
 | ||||||
|  | type GroupBy struct { | ||||||
|  | 	Having Where | ||||||
|  | } | ||||||
							
								
								
									
										23
									
								
								clause/join.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								clause/join.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | // Join join clause
 | ||||||
|  | type Join struct { | ||||||
|  | 	Table From   // From
 | ||||||
|  | 	Type  string // INNER, LEFT, RIGHT, FULL, CROSS JOIN
 | ||||||
|  | 	Using []Column | ||||||
|  | 	ON    Where | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TODO multiple joins
 | ||||||
|  | 
 | ||||||
|  | func (join Join) Build(builder Builder) { | ||||||
|  | 	// TODO
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (join Join) MergeExpression(expr Expression) { | ||||||
|  | 	// if j, ok := expr.(Join); ok {
 | ||||||
|  | 	// 	join.builders = append(join.builders, j.builders...)
 | ||||||
|  | 	// } else {
 | ||||||
|  | 	// 	join.builders = append(join.builders, expr)
 | ||||||
|  | 	// }
 | ||||||
|  | } | ||||||
							
								
								
									
										6
									
								
								clause/limit.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								clause/limit.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | // Limit limit clause
 | ||||||
|  | type Limit struct { | ||||||
|  | 	Offset uint | ||||||
|  | } | ||||||
							
								
								
									
										4
									
								
								clause/order_by.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								clause/order_by.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,4 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | type OrderBy struct { | ||||||
|  | } | ||||||
| @ -2,12 +2,6 @@ package clause | |||||||
| 
 | 
 | ||||||
| import "strings" | import "strings" | ||||||
| 
 | 
 | ||||||
| // Column quote with name
 |  | ||||||
| type Column struct { |  | ||||||
| 	Table string |  | ||||||
| 	Name  string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ////////////////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
| // Query Expressions
 | // Query Expressions
 | ||||||
| ////////////////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
|  | |||||||
							
								
								
									
										45
									
								
								clause/select.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								clause/select.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,45 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | // Select select attrs when querying, updating, creating
 | ||||||
|  | type Select struct { | ||||||
|  | 	SelectColumns []Column | ||||||
|  | 	OmitColumns   []Column | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SelectInterface select clause interface
 | ||||||
|  | type SelectInterface interface { | ||||||
|  | 	Selects() []Column | ||||||
|  | 	Omits() []Column | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s Select) Selects() []Column { | ||||||
|  | 	return s.SelectColumns | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s Select) Omits() []Column { | ||||||
|  | 	return s.OmitColumns | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s Select) Build(builder Builder) { | ||||||
|  | 	if len(s.SelectColumns) > 0 { | ||||||
|  | 		for idx, column := range s.SelectColumns { | ||||||
|  | 			if idx > 0 { | ||||||
|  | 				builder.WriteByte(',') | ||||||
|  | 			} | ||||||
|  | 			builder.WriteQuoted(column) | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		builder.WriteByte('*') | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s Select) MergeExpression(expr Expression) { | ||||||
|  | 	if v, ok := expr.(SelectInterface); ok { | ||||||
|  | 		if len(s.SelectColumns) == 0 { | ||||||
|  | 			s.SelectColumns = v.Selects() | ||||||
|  | 		} | ||||||
|  | 		if len(s.OmitColumns) == 0 { | ||||||
|  | 			s.OmitColumns = v.Omits() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										77
									
								
								clause/where.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								clause/where.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,77 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | // Where where clause
 | ||||||
|  | type Where struct { | ||||||
|  | 	AndConditions AddConditions | ||||||
|  | 	ORConditions  []ORConditions | ||||||
|  | 	builders      []Expression | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Name where clause name
 | ||||||
|  | func (where Where) Name() string { | ||||||
|  | 	return "WHERE" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Build build where clause
 | ||||||
|  | func (where Where) Build(builder Builder) { | ||||||
|  | 	var withConditions bool | ||||||
|  | 
 | ||||||
|  | 	if len(where.AndConditions) > 0 { | ||||||
|  | 		withConditions = true | ||||||
|  | 		where.AndConditions.Build(builder) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(where.builders) > 0 { | ||||||
|  | 		for _, b := range where.builders { | ||||||
|  | 			if withConditions { | ||||||
|  | 				builder.Write(" AND ") | ||||||
|  | 			} | ||||||
|  | 			withConditions = true | ||||||
|  | 			b.Build(builder) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var singleOrConditions []ORConditions | ||||||
|  | 	for _, or := range where.ORConditions { | ||||||
|  | 		if len(or) == 1 { | ||||||
|  | 			if withConditions { | ||||||
|  | 				builder.Write(" OR ") | ||||||
|  | 				or.Build(builder) | ||||||
|  | 			} else { | ||||||
|  | 				singleOrConditions = append(singleOrConditions, or) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			withConditions = true | ||||||
|  | 			builder.Write(" AND (") | ||||||
|  | 			or.Build(builder) | ||||||
|  | 			builder.WriteByte(')') | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, or := range singleOrConditions { | ||||||
|  | 		if withConditions { | ||||||
|  | 			builder.Write(" AND ") | ||||||
|  | 			or.Build(builder) | ||||||
|  | 		} else { | ||||||
|  | 			withConditions = true | ||||||
|  | 			or.Build(builder) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !withConditions { | ||||||
|  | 		builder.Write(" FALSE") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // MergeExpression merge where clauses
 | ||||||
|  | func (where Where) MergeExpression(expr Expression) { | ||||||
|  | 	if w, ok := expr.(Where); ok { | ||||||
|  | 		where.AndConditions = append(where.AndConditions, w.AndConditions...) | ||||||
|  | 		where.ORConditions = append(where.ORConditions, w.ORConditions...) | ||||||
|  | 		where.builders = append(where.builders, w.builders...) | ||||||
|  | 	} else { | ||||||
|  | 		where.builders = append(where.builders, expr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										4
									
								
								clause/with.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								clause/with.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,4 @@ | |||||||
|  | package clause | ||||||
|  | 
 | ||||||
|  | type With struct { | ||||||
|  | } | ||||||
| @ -1,7 +1,5 @@ | |||||||
| module github.com/jinzhu/gorm/dialects/mysql | module github.com/jinzhu/gorm/dialects/sqlite | ||||||
| 
 | 
 | ||||||
| go 1.13 | go 1.13 | ||||||
| 
 | 
 | ||||||
| require ( | require github.com/mattn/go-sqlite3 v2.0.3+incompatible | ||||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible |  | ||||||
| ) |  | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								dialects/sqlite/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								dialects/sqlite/go.sum
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | |||||||
|  | github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= | ||||||
|  | github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= | ||||||
| @ -1,6 +1,7 @@ | |||||||
| package sqlite | package sqlite | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/callbacks" | 	"github.com/jinzhu/gorm/callbacks" | ||||||
| 	_ "github.com/mattn/go-sqlite3" | 	_ "github.com/mattn/go-sqlite3" | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -1,15 +1,27 @@ | |||||||
| package sqlite_test | package sqlite_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/dialects/sqlite" | ||||||
|  | 	"github.com/jinzhu/gorm/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var DB *gorm.DB | var ( | ||||||
|  | 	DB  *gorm.DB | ||||||
|  | 	err error | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| func TestOpen(t *testing.T) { | func init() { | ||||||
| 	db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) | 	if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil { | ||||||
|  | 		panic(fmt.Sprintf("failed to initialize database, got error %v", err)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestSqlite(t *testing.T) { | ||||||
|  | 	tests.RunTestsSuit(t, DB) | ||||||
| } | } | ||||||
|  | |||||||
| @ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { | |||||||
| // First find first record that match given conditions, order by primary key
 | // First find first record that match given conditions, order by primary key
 | ||||||
| func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.callbacks.Create().Execute(tx.Limit(1).Order("id")) | 	tx.Statement.Dest = out | ||||||
|  | 	tx.Limit(1) | ||||||
|  | 	tx.callbacks.Query().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Row() *sql.Row { | func (db *DB) Row() *sql.Row { | ||||||
| 	// TODO
 |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Rows() (*sql.Rows, error) { | func (db *DB) Rows() (*sql.Rows, error) { | ||||||
| 	// TODO
 |  | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							| @ -2,4 +2,7 @@ module github.com/jinzhu/gorm | |||||||
| 
 | 
 | ||||||
| go 1.13 | go 1.13 | ||||||
| 
 | 
 | ||||||
| require github.com/jinzhu/inflection v1.0.0 | require ( | ||||||
|  | 	github.com/jinzhu/inflection v1.0.0 | ||||||
|  | 	gopkg.in/errgo.v2 v2.1.0 | ||||||
|  | ) | ||||||
|  | |||||||
							
								
								
									
										43
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								gorm.go
									
									
									
									
									
								
							| @ -2,6 +2,7 @@ package gorm | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm/clause" | 	"github.com/jinzhu/gorm/clause" | ||||||
| @ -12,36 +13,28 @@ import ( | |||||||
| // Config GORM config
 | // Config GORM config
 | ||||||
| type Config struct { | type Config struct { | ||||||
| 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | ||||||
| 	// You can cancel it by setting `SkipDefaultTransaction` to true
 | 	// You can disable it by setting `SkipDefaultTransaction` to true
 | ||||||
| 	SkipDefaultTransaction bool // TODO
 | 	SkipDefaultTransaction bool | ||||||
| 
 |  | ||||||
| 	// NamingStrategy tables, columns naming strategy
 | 	// NamingStrategy tables, columns naming strategy
 | ||||||
| 	NamingStrategy schema.Namer | 	NamingStrategy schema.Namer | ||||||
| 
 |  | ||||||
| 	// Logger
 | 	// Logger
 | ||||||
| 	Logger logger.Interface | 	Logger logger.Interface | ||||||
| 
 |  | ||||||
| 	// NowFunc the function to be used when creating a new timestamp
 | 	// NowFunc the function to be used when creating a new timestamp
 | ||||||
| 	NowFunc func() time.Time | 	NowFunc func() time.Time | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Dialector GORM database dialector
 |  | ||||||
| type Dialector interface { |  | ||||||
| 	Initialize(*DB) error |  | ||||||
| 	Migrator() Migrator |  | ||||||
| 	BindVar(stmt Statement, v interface{}) string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // DB GORM DB definition
 | // DB GORM DB definition
 | ||||||
| type DB struct { | type DB struct { | ||||||
| 	*Config | 	*Config | ||||||
| 	Dialector | 	Dialector | ||||||
| 	Instance | 	Instance | ||||||
| 	clone     bool | 	DB         CommonDB | ||||||
| 	callbacks *callbacks | 	clone      bool | ||||||
|  | 	callbacks  *callbacks | ||||||
|  | 	cacheStore *sync.Map | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Session session config when create new session
 | // Session session config when create session with Session() method
 | ||||||
| type Session struct { | type Session struct { | ||||||
| 	Context context.Context | 	Context context.Context | ||||||
| 	Logger  logger.Interface | 	Logger  logger.Interface | ||||||
| @ -67,10 +60,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	db = &DB{ | 	db = &DB{ | ||||||
| 		Config:    config, | 		Config:     config, | ||||||
| 		Dialector: dialector, | 		Dialector:  dialector, | ||||||
| 		clone:     true, | 		clone:      true, | ||||||
| 		callbacks: InitializeCallbacks(), | 		callbacks:  InitializeCallbacks(), | ||||||
|  | 		cacheStore: &sync.Map{}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if dialector != nil { | 	if dialector != nil { | ||||||
| @ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) { | |||||||
| 	return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) | 	return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Close() error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Set store value with key into current db instance's context
 | // Set store value with key into current db instance's context
 | ||||||
| func (db *DB) Set(key string, value interface{}) *DB { | func (db *DB) Set(key string, value interface{}) *DB { | ||||||
| 	tx := db.getInstance() | 	tx := db.getInstance() | ||||||
| @ -145,12 +135,15 @@ func (db *DB) getInstance() *DB { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return &DB{ | 		return &DB{ | ||||||
| 			Config:    db.Config, |  | ||||||
| 			Dialector: db.Dialector, |  | ||||||
| 			Instance: Instance{ | 			Instance: Instance{ | ||||||
| 				Context:   ctx, | 				Context:   ctx, | ||||||
| 				Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, | 				Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, | ||||||
| 			}, | 			}, | ||||||
|  | 			Config:     db.Config, | ||||||
|  | 			Dialector:  db.Dialector, | ||||||
|  | 			DB:         db.DB, | ||||||
|  | 			callbacks:  db.callbacks, | ||||||
|  | 			cacheStore: db.cacheStore, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										21
									
								
								interfaces.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								interfaces.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | |||||||
|  | package gorm | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Dialector GORM database dialector
 | ||||||
|  | type Dialector interface { | ||||||
|  | 	Initialize(*DB) error | ||||||
|  | 	Migrator() Migrator | ||||||
|  | 	BindVar(stmt Statement, v interface{}) string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CommonDB common db interface
 | ||||||
|  | type CommonDB interface { | ||||||
|  | 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||||
|  | 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||||
|  | 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||||
|  | 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||||
|  | } | ||||||
| @ -1,6 +1,7 @@ | |||||||
| package schema | package schema | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"go/ast" | 	"go/ast" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| @ -9,6 +10,9 @@ import ( | |||||||
| 	"github.com/jinzhu/gorm/logger" | 	"github.com/jinzhu/gorm/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // ErrUnsupportedDataType unsupported data type
 | ||||||
|  | var ErrUnsupportedDataType = errors.New("unsupported data type") | ||||||
|  | 
 | ||||||
| type Schema struct { | type Schema struct { | ||||||
| 	Name                    string | 	Name                    string | ||||||
| 	ModelType               reflect.Type | 	ModelType               reflect.Type | ||||||
| @ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | |||||||
| 
 | 
 | ||||||
| 	if modelType.Kind() != reflect.Struct { | 	if modelType.Kind() != reflect.Struct { | ||||||
| 		if modelType.PkgPath() == "" { | 		if modelType.PkgPath() == "" { | ||||||
| 			return nil, fmt.Errorf("unsupported data %+v when parsing model", dest) | 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||||
| 		} | 		} | ||||||
| 		return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath()) | 		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if v, ok := cacheStore.Load(modelType); ok { | 	if v, ok := cacheStore.Load(modelType); ok { | ||||||
| @ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, field := range schema.Fields { | 	for _, field := range schema.Fields { | ||||||
| 		if field.DBName == "" { | 		if field.DBName == "" && field.DataType != "" { | ||||||
| 			field.DBName = namer.ColumnName(schema.Table, field.Name) | 			field.DBName = namer.ColumnName(schema.Table, field.Name) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -2,24 +2,16 @@ package schema_test | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm/schema" | 	"github.com/jinzhu/gorm/schema" | ||||||
|  | 	"github.com/jinzhu/gorm/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { | func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { | ||||||
| 	t.Run("CheckSchema/"+s.Name, func(t *testing.T) { | 	t.Run("CheckSchema/"+s.Name, func(t *testing.T) { | ||||||
| 		equalFieldNames := []string{"Name", "Table"} | 		tests.AssertEqual(t, s, v, "Name", "Table") | ||||||
| 
 |  | ||||||
| 		for _, name := range equalFieldNames { |  | ||||||
| 			got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() |  | ||||||
| 			expects := reflect.ValueOf(v).FieldByName(name).Interface() |  | ||||||
| 			if !reflect.DeepEqual(got, expects) { |  | ||||||
| 				t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		for idx, field := range primaryFields { | 		for idx, field := range primaryFields { | ||||||
| 			var found bool | 			var found bool | ||||||
| @ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* | |||||||
| 		if parsedField, ok := s.FieldsByName[f.Name]; !ok { | 		if parsedField, ok := s.FieldsByName[f.Name]; !ok { | ||||||
| 			t.Errorf("schema %v failed to look up field with name %v", s, f.Name) | 			t.Errorf("schema %v failed to look up field with name %v", s, f.Name) | ||||||
| 		} else { | 		} else { | ||||||
| 			equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} | 			tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") | ||||||
| 
 |  | ||||||
| 			for _, name := range equalFieldNames { |  | ||||||
| 				got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() |  | ||||||
| 				expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() |  | ||||||
| 				if !reflect.DeepEqual(got, expects) { |  | ||||||
| 					t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { | 			if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { | ||||||
| 				t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) | 				t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) | ||||||
|  | |||||||
							
								
								
									
										31
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								statement.go
									
									
									
									
									
								
							| @ -10,6 +10,7 @@ import ( | |||||||
| 	"sync" | 	"sync" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm/clause" | 	"github.com/jinzhu/gorm/clause" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Instance db instance
 | // Instance db instance
 | ||||||
| @ -37,6 +38,7 @@ type Statement struct { | |||||||
| 	Clauses  map[string]clause.Clause | 	Clauses  map[string]clause.Clause | ||||||
| 	Settings sync.Map | 	Settings sync.Map | ||||||
| 	DB       *DB | 	DB       *DB | ||||||
|  | 	Schema   *schema.Schema | ||||||
| 
 | 
 | ||||||
| 	// SQL Builder
 | 	// SQL Builder
 | ||||||
| 	SQL       strings.Builder | 	SQL       strings.Builder | ||||||
| @ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Quote returns quoted value
 | // Quote returns quoted value
 | ||||||
| func (stmt Statement) Quote(field interface{}) (str string) { | func (stmt Statement) Quote(field interface{}) string { | ||||||
| 	// FIXME
 | 	var str strings.Builder | ||||||
| 	return fmt.Sprint(field) | 
 | ||||||
|  | 	switch v := field.(type) { | ||||||
|  | 	case clause.Table: | ||||||
|  | 		str.WriteString(v.Table) | ||||||
|  | 		if v.Alias != "" { | ||||||
|  | 			str.WriteString(" AS ") | ||||||
|  | 			str.WriteString(v.Alias) | ||||||
|  | 		} | ||||||
|  | 	case clause.Column: | ||||||
|  | 		if v.Table != "" { | ||||||
|  | 			str.WriteString(v.Table) | ||||||
|  | 			str.WriteByte('.') | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		str.WriteString(v.Name) | ||||||
|  | 		if v.Alias != "" { | ||||||
|  | 			str.WriteString(" AS ") | ||||||
|  | 			str.WriteString(v.Alias) | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		fmt.Sprint(field) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return str.String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Write write string
 | // Write write string
 | ||||||
|  | |||||||
| @ -1 +0,0 @@ | |||||||
| package tests |  | ||||||
							
								
								
									
										42
									
								
								tests/tests.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								tests/tests.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,42 @@ | |||||||
|  | package tests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func Now() *time.Time { | ||||||
|  | 	now := time.Now() | ||||||
|  | 	return &now | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func RunTestsSuit(t *testing.T, db *gorm.DB) { | ||||||
|  | 	TestCreate(t, db) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCreate(t *testing.T, db *gorm.DB) { | ||||||
|  | 	t.Run("Create", func(t *testing.T) { | ||||||
|  | 		var user = User{ | ||||||
|  | 			Name:     "create", | ||||||
|  | 			Age:      18, | ||||||
|  | 			Birthday: Now(), | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := db.Create(&user).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when create: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if user.ID == 0 { | ||||||
|  | 			t.Errorf("user's primary key should has value after create, got : %v", user.ID) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var newUser User | ||||||
|  | 		if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when query: %v", err) | ||||||
|  | 		} else { | ||||||
|  | 			AssertEqual(t, newUser, user, "Name", "Age", "Birthday") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										19
									
								
								tests/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								tests/utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | |||||||
|  | package tests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func AssertEqual(t *testing.T, r, e interface{}, names ...string) { | ||||||
|  | 	for _, name := range names { | ||||||
|  | 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() | ||||||
|  | 		expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||||
|  | 
 | ||||||
|  | 		if !reflect.DeepEqual(got, expects) { | ||||||
|  | 			t.Run(name, func(t *testing.T) { | ||||||
|  | 				t.Errorf("expects: %v, got %v", expects, got) | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu