Add NamedArg support
This commit is contained in:
		
							parent
							
								
									bc3728a18f
								
							
						
					
					
						commit
						bba569af2b
					
				| @ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. | |||||||
| * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point | * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point | ||||||
| * Context, Prepared Statment Mode, DryRun Mode | * Context, Prepared Statment Mode, DryRun Mode | ||||||
| * Batch Insert, FindInBatches, Find To Map | * Batch Insert, FindInBatches, Find To Map | ||||||
| * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints | * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg | ||||||
| * Composite Primary Key | * Composite Primary Key | ||||||
| * Auto Migrations | * Auto Migrations | ||||||
| * Logger | * Logger | ||||||
|  | |||||||
| @ -107,7 +107,6 @@ func (p *processor) Execute(db *DB) { | |||||||
| 	if !stmt.DB.DryRun { | 	if !stmt.DB.DryRun { | ||||||
| 		stmt.SQL.Reset() | 		stmt.SQL.Reset() | ||||||
| 		stmt.Vars = nil | 		stmt.Vars = nil | ||||||
| 		stmt.NamedVars = nil |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -265,6 +265,11 @@ func (db *DB) Unscoped() (tx *DB) { | |||||||
| func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { | func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.SQL = strings.Builder{} | 	tx.Statement.SQL = strings.Builder{} | ||||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | 
 | ||||||
|  | 	if strings.Contains(sql, "@") { | ||||||
|  | 		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||||
|  | 	} else { | ||||||
|  | 		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||||
|  | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package clause | package clause | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql" | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| ) | ) | ||||||
| @ -62,6 +63,64 @@ func (expr Expr) Build(builder Builder) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NamedExpr raw expression for named expr
 | ||||||
|  | type NamedExpr struct { | ||||||
|  | 	SQL  string | ||||||
|  | 	Vars []interface{} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Build build raw expression
 | ||||||
|  | func (expr NamedExpr) Build(builder Builder) { | ||||||
|  | 	var ( | ||||||
|  | 		idx      int | ||||||
|  | 		inName   bool | ||||||
|  | 		namedMap = make(map[string]interface{}, len(expr.Vars)) | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	for _, v := range expr.Vars { | ||||||
|  | 		switch value := v.(type) { | ||||||
|  | 		case sql.NamedArg: | ||||||
|  | 			namedMap[value.Name] = value.Value | ||||||
|  | 		case map[string]interface{}: | ||||||
|  | 			for k, v := range value { | ||||||
|  | 				namedMap[k] = v | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	name := make([]byte, 0, 10) | ||||||
|  | 
 | ||||||
|  | 	for _, v := range []byte(expr.SQL) { | ||||||
|  | 		if v == '@' && !inName { | ||||||
|  | 			inName = true | ||||||
|  | 			name = []byte{} | ||||||
|  | 		} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { | ||||||
|  | 			if inName { | ||||||
|  | 				if nv, ok := namedMap[string(name)]; ok { | ||||||
|  | 					builder.AddVar(builder, nv) | ||||||
|  | 				} else { | ||||||
|  | 					builder.WriteByte('@') | ||||||
|  | 					builder.WriteString(string(name)) | ||||||
|  | 				} | ||||||
|  | 				inName = false | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			builder.WriteByte(v) | ||||||
|  | 		} else if v == '?' { | ||||||
|  | 			builder.AddVar(builder, expr.Vars[idx]) | ||||||
|  | 			idx++ | ||||||
|  | 		} else if inName { | ||||||
|  | 			name = append(name, v) | ||||||
|  | 		} else { | ||||||
|  | 			builder.WriteByte(v) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if inName { | ||||||
|  | 		builder.AddVar(builder, namedMap[string(name)]) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // IN Whether a value is within a set of values
 | // IN Whether a value is within a set of values
 | ||||||
| type IN struct { | type IN struct { | ||||||
| 	Column interface{} | 	Column interface{} | ||||||
|  | |||||||
| @ -1,7 +1,9 @@ | |||||||
| package clause_test | package clause_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"reflect" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| @ -33,3 +35,51 @@ func TestExpr(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestNamedExpr(t *testing.T) { | ||||||
|  | 	results := []struct { | ||||||
|  | 		SQL          string | ||||||
|  | 		Result       string | ||||||
|  | 		Vars         []interface{} | ||||||
|  | 		ExpectedVars []interface{} | ||||||
|  | 	}{{ | ||||||
|  | 		SQL:    "create table ? (? ?, ? ?)", | ||||||
|  | 		Vars:   []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, | ||||||
|  | 		Result: "create table `users` (`id` int, `name` text)", | ||||||
|  | 	}, { | ||||||
|  | 		SQL:          "name1 = @name AND name2 = @name", | ||||||
|  | 		Vars:         []interface{}{sql.Named("name", "jinzhu")}, | ||||||
|  | 		Result:       "name1 = ? AND name2 = ?", | ||||||
|  | 		ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, | ||||||
|  | 	}, { | ||||||
|  | 		SQL:          "name1 = @name1 AND name2 = @name2 AND name3 = @name1", | ||||||
|  | 		Vars:         []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, | ||||||
|  | 		Result:       "name1 = ? AND name2 = ? AND name3 = ?", | ||||||
|  | 		ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, | ||||||
|  | 	}, { | ||||||
|  | 		SQL:          "name1 = @name1 AND name2 = @name2 AND name3 = @name1", | ||||||
|  | 		Vars:         []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}}, | ||||||
|  | 		Result:       "name1 = ? AND name2 = ? AND name3 = ?", | ||||||
|  | 		ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, | ||||||
|  | 	}, { | ||||||
|  | 		SQL:          "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", | ||||||
|  | 		Vars:         []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, | ||||||
|  | 		Result:       "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", | ||||||
|  | 		ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, | ||||||
|  | 	}} | ||||||
|  | 
 | ||||||
|  | 	for idx, result := range results { | ||||||
|  | 		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { | ||||||
|  | 			user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | ||||||
|  | 			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | ||||||
|  | 			clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) | ||||||
|  | 			if stmt.SQL.String() != result.Result { | ||||||
|  | 				t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { | ||||||
|  | 				t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -453,7 +453,13 @@ func (db *DB) RollbackTo(name string) *DB { | |||||||
| func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { | func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.SQL = strings.Builder{} | 	tx.Statement.SQL = strings.Builder{} | ||||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | 
 | ||||||
|  | 	if strings.Contains(sql, "@") { | ||||||
|  | 		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||||
|  | 	} else { | ||||||
|  | 		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	tx.callbacks.Raw().Execute(tx) | 	tx.callbacks.Raw().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										25
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								statement.go
									
									
									
									
									
								
							| @ -38,7 +38,6 @@ type Statement struct { | |||||||
| 	UpdatingColumn       bool | 	UpdatingColumn       bool | ||||||
| 	SQL                  strings.Builder | 	SQL                  strings.Builder | ||||||
| 	Vars                 []interface{} | 	Vars                 []interface{} | ||||||
| 	NamedVars            []sql.NamedArg |  | ||||||
| 	CurDestIndex         int | 	CurDestIndex         int | ||||||
| 	attrs                []interface{} | 	attrs                []interface{} | ||||||
| 	assigns              []interface{} | 	assigns              []interface{} | ||||||
| @ -148,14 +147,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | |||||||
| 
 | 
 | ||||||
| 		switch v := v.(type) { | 		switch v := v.(type) { | ||||||
| 		case sql.NamedArg: | 		case sql.NamedArg: | ||||||
| 			if len(v.Name) > 0 { | 			stmt.Vars = append(stmt.Vars, v.Value) | ||||||
| 				stmt.NamedVars = append(stmt.NamedVars, v) |  | ||||||
| 				writer.WriteByte('@') |  | ||||||
| 				writer.WriteString(v.Name) |  | ||||||
| 			} else { |  | ||||||
| 				stmt.Vars = append(stmt.Vars, v.Value) |  | ||||||
| 				stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value) |  | ||||||
| 			} |  | ||||||
| 		case clause.Column, clause.Table: | 		case clause.Column, clause.Table: | ||||||
| 			stmt.QuoteTo(writer, v) | 			stmt.QuoteTo(writer, v) | ||||||
| 		case clause.Expr: | 		case clause.Expr: | ||||||
| @ -234,16 +226,19 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { | |||||||
| 
 | 
 | ||||||
| // BuildCondition build condition
 | // BuildCondition build condition
 | ||||||
| func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { | func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { | ||||||
| 	if sql, ok := query.(string); ok { | 	if s, ok := query.(string); ok { | ||||||
| 		// if it is a number, then treats it as primary key
 | 		// if it is a number, then treats it as primary key
 | ||||||
| 		if _, err := strconv.Atoi(sql); err != nil { | 		if _, err := strconv.Atoi(s); err != nil { | ||||||
| 			if sql == "" && len(args) == 0 { | 			if s == "" && len(args) == 0 { | ||||||
| 				return | 				return | ||||||
| 			} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { | 			} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { | ||||||
| 				// looks like a where condition
 | 				// looks like a where condition
 | ||||||
| 				return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} | 				return []clause.Expression{clause.Expr{SQL: s, Vars: args}} | ||||||
|  | 			} else if len(args) > 0 && strings.Contains(s, "@") { | ||||||
|  | 				// looks like a named query
 | ||||||
|  | 				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} | ||||||
| 			} else if len(args) == 1 { | 			} else if len(args) == 1 { | ||||||
| 				return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} | 				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										57
									
								
								tests/named_argument_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								tests/named_argument_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,57 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | 	. "gorm.io/gorm/utils/tests" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestNamedArg(t *testing.T) { | ||||||
|  | 	type NamedUser struct { | ||||||
|  | 		gorm.Model | ||||||
|  | 		Name1 string | ||||||
|  | 		Name2 string | ||||||
|  | 		Name3 string | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Migrator().DropTable(&NamedUser{}) | ||||||
|  | 	DB.AutoMigrate(&NamedUser{}) | ||||||
|  | 
 | ||||||
|  | 	namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} | ||||||
|  | 	DB.Create(&namedUser) | ||||||
|  | 
 | ||||||
|  | 	var result NamedUser | ||||||
|  | 	DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) | ||||||
|  | 
 | ||||||
|  | 	AssertEqual(t, result, namedUser) | ||||||
|  | 
 | ||||||
|  | 	var result2 NamedUser | ||||||
|  | 	DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) | ||||||
|  | 
 | ||||||
|  | 	AssertEqual(t, result2, namedUser) | ||||||
|  | 
 | ||||||
|  | 	var result3 NamedUser | ||||||
|  | 	DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) | ||||||
|  | 
 | ||||||
|  | 	AssertEqual(t, result3, namedUser) | ||||||
|  | 
 | ||||||
|  | 	var result4 NamedUser | ||||||
|  | 	if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { | ||||||
|  | 		t.Errorf("failed to update with named arg") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	AssertEqual(t, result4, namedUser) | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { | ||||||
|  | 		t.Errorf("failed to update with named arg") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var result5 NamedUser | ||||||
|  | 	if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil { | ||||||
|  | 		t.Errorf("failed to update with named arg") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	AssertEqual(t, result4, namedUser) | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu