Add Raw, Row, Rows
This commit is contained in:
		
							parent
							
								
									fab7d96da5
								
							
						
					
					
						commit
						215f5e7765
					
				| @ -38,4 +38,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { | ||||
| 	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) | ||||
| 	updateCallback.Register("gorm:after_update", AfterUpdate) | ||||
| 	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||
| 
 | ||||
| 	db.Callback().Row().Register("gorm:raw", RowQuery) | ||||
| 	db.Callback().Raw().Register("gorm:raw", RawExec) | ||||
| } | ||||
|  | ||||
							
								
								
									
										11
									
								
								callbacks/raw.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								callbacks/raw.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import "github.com/jinzhu/gorm" | ||||
| 
 | ||||
| func RawExec(db *gorm.DB) { | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	db.RowsAffected, _ = result.RowsAffected() | ||||
| 	if err != nil { | ||||
| 		db.AddError(err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										19
									
								
								callbacks/row.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								callbacks/row.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| ) | ||||
| 
 | ||||
| func RowQuery(db *gorm.DB) { | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Select{}) | ||||
| 	db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 
 | ||||
| 	db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 
 | ||||
| 	if _, ok := db.Get("rows"); ok { | ||||
| 		db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	} else { | ||||
| 		db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	} | ||||
| } | ||||
| @ -222,5 +222,8 @@ func (db *DB) Unscoped() (tx *DB) { | ||||
| 
 | ||||
| func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	stmt := tx.Statement | ||||
| 	stmt.SQL = strings.Builder{} | ||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(stmt) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:        db, | ||||
| 		Dialector: dialector, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||
|  | ||||
| @ -29,7 +29,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:        db, | ||||
| 		Dialector: dialector, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||
|  | ||||
| @ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:        db, | ||||
| 		Dialector: dialector, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||
|  | ||||
| @ -27,7 +27,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:        db, | ||||
| 		Dialector: dialector, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||
|  | ||||
| @ -22,6 +22,10 @@ func init() { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSqlite(t *testing.T) { | ||||
| func TestCURD(t *testing.T) { | ||||
| 	tests.RunTestsSuit(t, DB) | ||||
| } | ||||
| 
 | ||||
| func TestMigrate(t *testing.T) { | ||||
| 	tests.TestMigrate(t, DB) | ||||
| } | ||||
|  | ||||
| @ -108,11 +108,15 @@ func (db *DB) Count(value interface{}) (tx *DB) { | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Row() *sql.Row { | ||||
| 	return nil | ||||
| 	tx := db.getInstance() | ||||
| 	tx.callbacks.Row().Execute(tx) | ||||
| 	return tx.Statement.Dest.(*sql.Row) | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Rows() (*sql.Rows, error) { | ||||
| 	return nil, nil | ||||
| 	tx := db.Set("rows", true) | ||||
| 	tx.callbacks.Row().Execute(tx) | ||||
| 	return tx.Statement.Dest.(*sql.Rows), tx.Error | ||||
| } | ||||
| 
 | ||||
| // Scan scan value to a struct
 | ||||
| @ -162,5 +166,6 @@ func (db *DB) Rollback() (tx *DB) { | ||||
| 
 | ||||
| func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.callbacks.Raw().Execute(tx) | ||||
| 	return | ||||
| } | ||||
|  | ||||
							
								
								
									
										5
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								gorm.go
									
									
									
									
									
								
							| @ -138,6 +138,11 @@ func (db *DB) Callback() *callbacks { | ||||
| 	return db.callbacks | ||||
| } | ||||
| 
 | ||||
| // AutoMigrate run auto migration for given models
 | ||||
| func (db *DB) AutoMigrate(dst ...interface{}) error { | ||||
| 	return db.Migrator().AutoMigrate(dst...) | ||||
| } | ||||
| 
 | ||||
| func (db *DB) getInstance() *DB { | ||||
| 	if db.clone { | ||||
| 		ctx := db.Instance.Context | ||||
|  | ||||
| @ -265,8 +265,15 @@ func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { | ||||
| 	return nil, gorm.ErrNotImplemented | ||||
| func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { | ||||
| 	err = m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||
| 		rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() | ||||
| 		if err == nil { | ||||
| 			columnTypes, err = rows.ColumnTypes() | ||||
| 		} | ||||
| 		return err | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (m Migrator) CreateView(name string, option gorm.ViewOption) error { | ||||
|  | ||||
| @ -17,9 +17,12 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check { | ||||
| 	for _, field := range schema.FieldsByDBName { | ||||
| 		if chk := field.TagSettings["CHECK"]; chk != "" { | ||||
| 			names := strings.Split(chk, ",") | ||||
| 			if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) { | ||||
| 			if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) { | ||||
| 				checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} | ||||
| 			} else { | ||||
| 				if names[0] == "" { | ||||
| 					chk = strings.Join(names[1:], ",") | ||||
| 				} | ||||
| 				name := schema.namer.CheckerName(schema.Table, field.DBName) | ||||
| 				checks[name] = Check{Name: name, Constraint: chk, Field: field} | ||||
| 			} | ||||
|  | ||||
							
								
								
									
										55
									
								
								schema/check_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								schema/check_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,55 @@ | ||||
| package schema_test | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| type UserCheck struct { | ||||
| 	Name  string `gorm:"check:name_checker,name <> 'jinzhu'"` | ||||
| 	Name2 string `gorm:"check:name <> 'jinzhu'"` | ||||
| 	Name3 string `gorm:"check:,name <> 'jinzhu'"` | ||||
| } | ||||
| 
 | ||||
| func TestParseCheck(t *testing.T) { | ||||
| 	user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to parse user check, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	results := map[string]schema.Check{ | ||||
| 		"name_checker": { | ||||
| 			Name:       "name_checker", | ||||
| 			Constraint: "name <> 'jinzhu'", | ||||
| 		}, | ||||
| 		"chk_user_checks_name2": { | ||||
| 			Name:       "chk_user_checks_name2", | ||||
| 			Constraint: "name <> 'jinzhu'", | ||||
| 		}, | ||||
| 		"chk_user_checks_name3": { | ||||
| 			Name:       "chk_user_checks_name3", | ||||
| 			Constraint: "name <> 'jinzhu'", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	checks := user.ParseCheckConstraints() | ||||
| 
 | ||||
| 	for k, result := range results { | ||||
| 		v, ok := checks[k] | ||||
| 		if !ok { | ||||
| 			t.Errorf("Failed to found check %v from parsed checks %+v", k, checks) | ||||
| 		} | ||||
| 
 | ||||
| 		for _, name := range []string{"Name", "Constraint"} { | ||||
| 			if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { | ||||
| 				t.Errorf( | ||||
| 					"check %v %v should equal, expects %v, got %v", | ||||
| 					k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), | ||||
| 				) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -21,7 +21,7 @@ type UserIndex struct { | ||||
| func TestParseIndex(t *testing.T) { | ||||
| 	user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to parse user index index, got error %v", err) | ||||
| 		t.Fatalf("failed to parse user index, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	results := map[string]schema.Index{ | ||||
|  | ||||
| @ -317,7 +317,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { | ||||
| 		settings = ParseTagSetting(str, ",") | ||||
| 	) | ||||
| 
 | ||||
| 	if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) { | ||||
| 	if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) { | ||||
| 		name = str[0:idx] | ||||
| 	} else { | ||||
| 		name = rel.Schema.namer.RelationshipFKName(*rel) | ||||
| @ -339,5 +339,9 @@ func (rel *Relationship) ParseConstraint() *Constraint { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if constraint.ReferenceSchema == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return &constraint | ||||
| } | ||||
|  | ||||
							
								
								
									
										19
									
								
								tests/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								tests/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func TestMigrate(t *testing.T, db *gorm.DB) { | ||||
| 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} | ||||
| 
 | ||||
| 	db.AutoMigrate(allModels...) | ||||
| 
 | ||||
| 	for _, m := range allModels { | ||||
| 		if !db.Migrator().HasTable(m) { | ||||
| 			t.Errorf("Failed to create table for %+v", m) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu