Add NamedArg support
This commit is contained in:
		
							parent
							
								
									bc3728a18f
								
							
						
					
					
						commit
						bba569af2b
					
				| @ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. | ||||
| * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point | ||||
| * Context, Prepared Statment Mode, DryRun Mode | ||||
| * Batch Insert, FindInBatches, Find To Map | ||||
| * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints | ||||
| * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg | ||||
| * Composite Primary Key | ||||
| * Auto Migrations | ||||
| * Logger | ||||
|  | ||||
| @ -107,7 +107,6 @@ func (p *processor) Execute(db *DB) { | ||||
| 	if !stmt.DB.DryRun { | ||||
| 		stmt.SQL.Reset() | ||||
| 		stmt.Vars = nil | ||||
| 		stmt.NamedVars = nil | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -265,6 +265,11 @@ func (db *DB) Unscoped() (tx *DB) { | ||||
| func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.SQL = strings.Builder{} | ||||
| 
 | ||||
| 	if strings.Contains(sql, "@") { | ||||
| 		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||
| 	} else { | ||||
| 		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package clause | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"reflect" | ||||
| ) | ||||
| @ -62,6 +63,64 @@ func (expr Expr) Build(builder Builder) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // NamedExpr raw expression for named expr
 | ||||
| type NamedExpr struct { | ||||
| 	SQL  string | ||||
| 	Vars []interface{} | ||||
| } | ||||
| 
 | ||||
| // Build build raw expression
 | ||||
| func (expr NamedExpr) Build(builder Builder) { | ||||
| 	var ( | ||||
| 		idx      int | ||||
| 		inName   bool | ||||
| 		namedMap = make(map[string]interface{}, len(expr.Vars)) | ||||
| 	) | ||||
| 
 | ||||
| 	for _, v := range expr.Vars { | ||||
| 		switch value := v.(type) { | ||||
| 		case sql.NamedArg: | ||||
| 			namedMap[value.Name] = value.Value | ||||
| 		case map[string]interface{}: | ||||
| 			for k, v := range value { | ||||
| 				namedMap[k] = v | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	name := make([]byte, 0, 10) | ||||
| 
 | ||||
| 	for _, v := range []byte(expr.SQL) { | ||||
| 		if v == '@' && !inName { | ||||
| 			inName = true | ||||
| 			name = []byte{} | ||||
| 		} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { | ||||
| 			if inName { | ||||
| 				if nv, ok := namedMap[string(name)]; ok { | ||||
| 					builder.AddVar(builder, nv) | ||||
| 				} else { | ||||
| 					builder.WriteByte('@') | ||||
| 					builder.WriteString(string(name)) | ||||
| 				} | ||||
| 				inName = false | ||||
| 			} | ||||
| 
 | ||||
| 			builder.WriteByte(v) | ||||
| 		} else if v == '?' { | ||||
| 			builder.AddVar(builder, expr.Vars[idx]) | ||||
| 			idx++ | ||||
| 		} else if inName { | ||||
| 			name = append(name, v) | ||||
| 		} else { | ||||
| 			builder.WriteByte(v) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if inName { | ||||
| 		builder.AddVar(builder, namedMap[string(name)]) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // IN Whether a value is within a set of values
 | ||||
| type IN struct { | ||||
| 	Column interface{} | ||||
|  | ||||
| @ -1,7 +1,9 @@ | ||||
| package clause_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| @ -33,3 +35,51 @@ func TestExpr(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNamedExpr(t *testing.T) { | ||||
| 	results := []struct { | ||||
| 		SQL          string | ||||
| 		Result       string | ||||
| 		Vars         []interface{} | ||||
| 		ExpectedVars []interface{} | ||||
| 	}{{ | ||||
| 		SQL:    "create table ? (? ?, ? ?)", | ||||
| 		Vars:   []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, | ||||
| 		Result: "create table `users` (`id` int, `name` text)", | ||||
| 	}, { | ||||
| 		SQL:          "name1 = @name AND name2 = @name", | ||||
| 		Vars:         []interface{}{sql.Named("name", "jinzhu")}, | ||||
| 		Result:       "name1 = ? AND name2 = ?", | ||||
| 		ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, | ||||
| 	}, { | ||||
| 		SQL:          "name1 = @name1 AND name2 = @name2 AND name3 = @name1", | ||||
| 		Vars:         []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, | ||||
| 		Result:       "name1 = ? AND name2 = ? AND name3 = ?", | ||||
| 		ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, | ||||
| 	}, { | ||||
| 		SQL:          "name1 = @name1 AND name2 = @name2 AND name3 = @name1", | ||||
| 		Vars:         []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}}, | ||||
| 		Result:       "name1 = ? AND name2 = ? AND name3 = ?", | ||||
| 		ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, | ||||
| 	}, { | ||||
| 		SQL:          "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", | ||||
| 		Vars:         []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, | ||||
| 		Result:       "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", | ||||
| 		ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, | ||||
| 	}} | ||||
| 
 | ||||
| 	for idx, result := range results { | ||||
| 		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { | ||||
| 			user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | ||||
| 			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | ||||
| 			clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) | ||||
| 			if stmt.SQL.String() != result.Result { | ||||
| 				t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) | ||||
| 			} | ||||
| 
 | ||||
| 			if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { | ||||
| 				t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -453,7 +453,13 @@ func (db *DB) RollbackTo(name string) *DB { | ||||
| func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.SQL = strings.Builder{} | ||||
| 
 | ||||
| 	if strings.Contains(sql, "@") { | ||||
| 		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||
| 	} else { | ||||
| 		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||
| 	} | ||||
| 
 | ||||
| 	tx.callbacks.Raw().Execute(tx) | ||||
| 	return | ||||
| } | ||||
|  | ||||
							
								
								
									
										23
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								statement.go
									
									
									
									
									
								
							| @ -38,7 +38,6 @@ type Statement struct { | ||||
| 	UpdatingColumn       bool | ||||
| 	SQL                  strings.Builder | ||||
| 	Vars                 []interface{} | ||||
| 	NamedVars            []sql.NamedArg | ||||
| 	CurDestIndex         int | ||||
| 	attrs                []interface{} | ||||
| 	assigns              []interface{} | ||||
| @ -148,14 +147,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | ||||
| 
 | ||||
| 		switch v := v.(type) { | ||||
| 		case sql.NamedArg: | ||||
| 			if len(v.Name) > 0 { | ||||
| 				stmt.NamedVars = append(stmt.NamedVars, v) | ||||
| 				writer.WriteByte('@') | ||||
| 				writer.WriteString(v.Name) | ||||
| 			} else { | ||||
| 			stmt.Vars = append(stmt.Vars, v.Value) | ||||
| 				stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value) | ||||
| 			} | ||||
| 		case clause.Column, clause.Table: | ||||
| 			stmt.QuoteTo(writer, v) | ||||
| 		case clause.Expr: | ||||
| @ -234,16 +226,19 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { | ||||
| 
 | ||||
| // BuildCondition build condition
 | ||||
| func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { | ||||
| 	if sql, ok := query.(string); ok { | ||||
| 	if s, ok := query.(string); ok { | ||||
| 		// if it is a number, then treats it as primary key
 | ||||
| 		if _, err := strconv.Atoi(sql); err != nil { | ||||
| 			if sql == "" && len(args) == 0 { | ||||
| 		if _, err := strconv.Atoi(s); err != nil { | ||||
| 			if s == "" && len(args) == 0 { | ||||
| 				return | ||||
| 			} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { | ||||
| 			} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { | ||||
| 				// looks like a where condition
 | ||||
| 				return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} | ||||
| 				return []clause.Expression{clause.Expr{SQL: s, Vars: args}} | ||||
| 			} else if len(args) > 0 && strings.Contains(s, "@") { | ||||
| 				// looks like a named query
 | ||||
| 				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} | ||||
| 			} else if len(args) == 1 { | ||||
| 				return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} | ||||
| 				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										57
									
								
								tests/named_argument_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								tests/named_argument_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,57 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestNamedArg(t *testing.T) { | ||||
| 	type NamedUser struct { | ||||
| 		gorm.Model | ||||
| 		Name1 string | ||||
| 		Name2 string | ||||
| 		Name3 string | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Migrator().DropTable(&NamedUser{}) | ||||
| 	DB.AutoMigrate(&NamedUser{}) | ||||
| 
 | ||||
| 	namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} | ||||
| 	DB.Create(&namedUser) | ||||
| 
 | ||||
| 	var result NamedUser | ||||
| 	DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) | ||||
| 
 | ||||
| 	AssertEqual(t, result, namedUser) | ||||
| 
 | ||||
| 	var result2 NamedUser | ||||
| 	DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) | ||||
| 
 | ||||
| 	AssertEqual(t, result2, namedUser) | ||||
| 
 | ||||
| 	var result3 NamedUser | ||||
| 	DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) | ||||
| 
 | ||||
| 	AssertEqual(t, result3, namedUser) | ||||
| 
 | ||||
| 	var result4 NamedUser | ||||
| 	if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { | ||||
| 		t.Errorf("failed to update with named arg") | ||||
| 	} | ||||
| 
 | ||||
| 	AssertEqual(t, result4, namedUser) | ||||
| 
 | ||||
| 	if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { | ||||
| 		t.Errorf("failed to update with named arg") | ||||
| 	} | ||||
| 
 | ||||
| 	var result5 NamedUser | ||||
| 	if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil { | ||||
| 		t.Errorf("failed to update with named arg") | ||||
| 	} | ||||
| 
 | ||||
| 	AssertEqual(t, result4, namedUser) | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu