Test SubQuery
This commit is contained in:
		
							parent
							
								
									1559fe24e5
								
							
						
					
					
						commit
						4e147e1256
					
				| @ -80,7 +80,7 @@ func (p *processor) Execute(db *DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if stmt.Model != nil { | ||||
| 			if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { | ||||
| 			if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -63,36 +63,38 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 					db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") | ||||
| 				} | ||||
| 
 | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 				if !db.DryRun { | ||||
| 					result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 				if err == nil { | ||||
| 					if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { | ||||
| 						if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { | ||||
| 							if insertID, err := result.LastInsertId(); err == nil { | ||||
| 								switch db.Statement.ReflectValue.Kind() { | ||||
| 								case reflect.Slice, reflect.Array: | ||||
| 									if config.LastInsertIDReversed { | ||||
| 										for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||
| 											db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 											insertID-- | ||||
| 										} | ||||
| 									} else { | ||||
| 										for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 											db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 											insertID++ | ||||
| 					if err == nil { | ||||
| 						if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { | ||||
| 							if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { | ||||
| 								if insertID, err := result.LastInsertId(); err == nil { | ||||
| 									switch db.Statement.ReflectValue.Kind() { | ||||
| 									case reflect.Slice, reflect.Array: | ||||
| 										if config.LastInsertIDReversed { | ||||
| 											for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||
| 												db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 												insertID-- | ||||
| 											} | ||||
| 										} else { | ||||
| 											for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 												db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 												insertID++ | ||||
| 											} | ||||
| 										} | ||||
| 									case reflect.Struct: | ||||
| 										db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||
| 									} | ||||
| 								case reflect.Struct: | ||||
| 									db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||
| 								} else { | ||||
| 									db.AddError(err) | ||||
| 								} | ||||
| 							} else { | ||||
| 								db.AddError(err) | ||||
| 							} | ||||
| 						} | ||||
| 						db.RowsAffected, _ = result.RowsAffected() | ||||
| 					} else { | ||||
| 						db.AddError(err) | ||||
| 					} | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 				} else { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| @ -135,42 +137,44 @@ func CreateWithReturning(db *gorm.DB) { | ||||
| 				idx++ | ||||
| 			} | ||||
| 
 | ||||
| 			rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 			if !db.DryRun { | ||||
| 				rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 			if err == nil { | ||||
| 				defer rows.Close() | ||||
| 				if err == nil { | ||||
| 					defer rows.Close() | ||||
| 
 | ||||
| 				switch db.Statement.ReflectValue.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for rows.Next() { | ||||
| 					switch db.Statement.ReflectValue.Kind() { | ||||
| 					case reflect.Slice, reflect.Array: | ||||
| 						for rows.Next() { | ||||
| 							for idx, field := range fields { | ||||
| 								values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||
| 							} | ||||
| 							if err := rows.Scan(values...); err != nil { | ||||
| 								db.AddError(err) | ||||
| 							} | ||||
| 							db.RowsAffected++ | ||||
| 						} | ||||
| 					case reflect.Struct: | ||||
| 						for idx, field := range fields { | ||||
| 							values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||
| 							values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 						} | ||||
| 						if err := rows.Scan(values...); err != nil { | ||||
| 							db.AddError(err) | ||||
| 						} | ||||
| 						db.RowsAffected++ | ||||
| 					} | ||||
| 				case reflect.Struct: | ||||
| 					for idx, field := range fields { | ||||
| 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 					} | ||||
| 
 | ||||
| 					if rows.Next() { | ||||
| 						db.RowsAffected++ | ||||
| 						err = rows.Scan(values...) | ||||
| 						if rows.Next() { | ||||
| 							db.RowsAffected++ | ||||
| 							err = rows.Scan(values...) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if err != nil { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} else { | ||||
| 			if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 				if err != nil { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} else { | ||||
| 				db.AddError(err) | ||||
| 				if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 				} else { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -72,12 +72,14 @@ func Delete(db *gorm.DB) { | ||||
| 			db.Statement.Build("DELETE", "FROM", "WHERE") | ||||
| 		} | ||||
| 
 | ||||
| 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 		if !db.DryRun { | ||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 		if err == nil { | ||||
| 			db.RowsAffected, _ = result.RowsAffected() | ||||
| 		} else { | ||||
| 			db.AddError(err) | ||||
| 			if err == nil { | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 			} else { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -23,14 +23,16 @@ func Query(db *gorm.DB) { | ||||
| 			BuildQuerySQL(db) | ||||
| 		} | ||||
| 
 | ||||
| 		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 		if err != nil { | ||||
| 			db.AddError(err) | ||||
| 			return | ||||
| 		} | ||||
| 		defer rows.Close() | ||||
| 		if !db.DryRun { | ||||
| 			rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 			if err != nil { | ||||
| 				db.AddError(err) | ||||
| 				return | ||||
| 			} | ||||
| 			defer rows.Close() | ||||
| 
 | ||||
| 		gorm.Scan(rows, db, false) | ||||
| 			gorm.Scan(rows, db, false) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -85,12 +85,14 @@ func Update(db *gorm.DB) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 		if !db.DryRun { | ||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 		if err == nil { | ||||
| 			db.RowsAffected, _ = result.RowsAffected() | ||||
| 		} else { | ||||
| 			db.AddError(err) | ||||
| 			if err == nil { | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 			} else { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -61,41 +61,43 @@ func Create(db *gorm.DB) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	if !db.DryRun { | ||||
| 		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 	if err == nil { | ||||
| 		defer rows.Close() | ||||
| 		if err == nil { | ||||
| 			defer rows.Close() | ||||
| 
 | ||||
| 		switch db.Statement.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			if len(db.Statement.Schema.PrimaryFields) > 0 { | ||||
| 				values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				if len(db.Statement.Schema.PrimaryFields) > 0 { | ||||
| 					values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) | ||||
| 
 | ||||
| 					for rows.Next() { | ||||
| 						for idx, field := range db.Statement.Schema.PrimaryFields { | ||||
| 							values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||
| 						} | ||||
| 
 | ||||
| 						db.RowsAffected++ | ||||
| 						db.AddError(rows.Scan(values...)) | ||||
| 					} | ||||
| 				} | ||||
| 			case reflect.Struct: | ||||
| 				if len(db.Statement.Schema.PrimaryFields) > 0 { | ||||
| 					values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) | ||||
| 
 | ||||
| 				for rows.Next() { | ||||
| 					for idx, field := range db.Statement.Schema.PrimaryFields { | ||||
| 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||
| 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 					} | ||||
| 
 | ||||
| 					db.RowsAffected++ | ||||
| 					db.AddError(rows.Scan(values...)) | ||||
| 				} | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			if len(db.Statement.Schema.PrimaryFields) > 0 { | ||||
| 				values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) | ||||
| 
 | ||||
| 				for idx, field := range db.Statement.Schema.PrimaryFields { | ||||
| 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 				} | ||||
| 
 | ||||
| 				if rows.Next() { | ||||
| 					db.RowsAffected++ | ||||
| 					db.AddError(rows.Scan(values...)) | ||||
| 					if rows.Next() { | ||||
| 						db.RowsAffected++ | ||||
| 						db.AddError(rows.Scan(values...)) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			db.AddError(err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		db.AddError(err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										7
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								gorm.go
									
									
									
									
									
								
							| @ -22,6 +22,8 @@ type Config struct { | ||||
| 	Logger logger.Interface | ||||
| 	// NowFunc the function to be used when creating a new timestamp
 | ||||
| 	NowFunc func() time.Time | ||||
| 	// DryRun generate sql without execute
 | ||||
| 	DryRun bool | ||||
| 
 | ||||
| 	// ClauseBuilders clause builder
 | ||||
| 	ClauseBuilders map[string]clause.ClauseBuilder | ||||
| @ -45,6 +47,7 @@ type DB struct { | ||||
| 
 | ||||
| // Session session config when create session with Session() method
 | ||||
| type Session struct { | ||||
| 	DryRun         bool | ||||
| 	WithConditions bool | ||||
| 	Context        context.Context | ||||
| 	Logger         logger.Interface | ||||
| @ -120,6 +123,10 @@ func (db *DB) Session(config *Session) *DB { | ||||
| 		tx.clone = 3 | ||||
| 	} | ||||
| 
 | ||||
| 	if config.DryRun { | ||||
| 		tx.Config.DryRun = true | ||||
| 	} | ||||
| 
 | ||||
| 	if config.Logger != nil { | ||||
| 		tx.Config.Logger = config.Logger | ||||
| 	} | ||||
|  | ||||
| @ -22,8 +22,10 @@ func isPrintable(s []byte) bool { | ||||
| 
 | ||||
| var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} | ||||
| 
 | ||||
| func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { | ||||
| func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { | ||||
| 	var convertParams func(interface{}, int) | ||||
| 	var vars = make([]interface{}, len(avars)) | ||||
| 	copy(vars, avars) | ||||
| 
 | ||||
| 	convertParams = func(v interface{}, idx int) { | ||||
| 		switch v := v.(type) { | ||||
|  | ||||
							
								
								
									
										14
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								statement.go
									
									
									
									
									
								
							| @ -157,6 +157,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | ||||
| 			} else { | ||||
| 				writer.WriteString("(NULL)") | ||||
| 			} | ||||
| 		case *DB: | ||||
| 			result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement | ||||
| 			writer.WriteString(result.SQL.String()) | ||||
| 			stmt.Vars = append(stmt.Vars, result.Vars...) | ||||
| 		default: | ||||
| 			switch rv := reflect.ValueOf(v); rv.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| @ -226,7 +230,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con | ||||
| 		case clause.Expression: | ||||
| 			conds = append(conds, v) | ||||
| 		case *DB: | ||||
| 			if v.Statement == nil { | ||||
| 			if v.Statement != nil { | ||||
| 				if cs, ok := v.Statement.Clauses["WHERE"]; ok { | ||||
| 					conds = append(conds, cs.Expression) | ||||
| 				} | ||||
| @ -367,7 +371,9 @@ func (stmt *Statement) reinit() { | ||||
| 	// })
 | ||||
| 
 | ||||
| 	// stmt.Schema = nil
 | ||||
| 	stmt.SQL.Reset() | ||||
| 	stmt.Vars = nil | ||||
| 	stmt.NamedVars = nil | ||||
| 	if !stmt.DB.DryRun { | ||||
| 		stmt.SQL.Reset() | ||||
| 		stmt.Vars = nil | ||||
| 		stmt.NamedVars = nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -349,3 +349,89 @@ func TestSearchWithMap(t *testing.T) { | ||||
| 		t.Errorf("Search all records with inline multiple value map") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSubQuery(t *testing.T) { | ||||
| 	users := []User{ | ||||
| 		{Name: "subquery_1", Age: 10}, | ||||
| 		{Name: "subquery_2", Age: 20}, | ||||
| 		{Name: "subquery_3", Age: 30}, | ||||
| 		{Name: "subquery_4", Age: 40}, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil { | ||||
| 		t.Fatalf("got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(users) != 4 { | ||||
| 		t.Errorf("Four users should be found, instead found %d", len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB. | ||||
| 		Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users) | ||||
| 
 | ||||
| 	if len(users) != 2 { | ||||
| 		t.Errorf("Two users should be found, instead found %d", len(users)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSubQueryWithRaw(t *testing.T) { | ||||
| 	users := []User{ | ||||
| 		{Name: "subquery_raw_1", Age: 10}, | ||||
| 		{Name: "subquery_raw_2", Age: 20}, | ||||
| 		{Name: "subquery_raw_3", Age: 30}, | ||||
| 		{Name: "subquery_raw_4", Age: 40}, | ||||
| 	} | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	var count int64 | ||||
| 	err := DB.Raw("select count(*) from (?) tmp", | ||||
| 		DB.Table("users"). | ||||
| 			Select("name"). | ||||
| 			Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). | ||||
| 			Group("name"), | ||||
| 	).Count(&count).Error | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Expected to get no errors, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count != 1 { | ||||
| 		t.Errorf("Row count must be 1, instead got %d", count) | ||||
| 	} | ||||
| 
 | ||||
| 	err = DB.Raw("select count(*) from (?) tmp", | ||||
| 		DB.Table("users"). | ||||
| 			Select("name"). | ||||
| 			Where("name LIKE ?", "subquery_raw%"). | ||||
| 			Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). | ||||
| 			Group("name"), | ||||
| 	).Count(&count).Error | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Expected to get no errors, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count != 2 { | ||||
| 		t.Errorf("Row count must be 2, instead got %d", count) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSubQueryWithHaving(t *testing.T) { | ||||
| 	users := []User{ | ||||
| 		{Name: "subquery_having_1", Age: 10}, | ||||
| 		{Name: "subquery_having_2", Age: 20}, | ||||
| 		{Name: "subquery_having_3", Age: 30}, | ||||
| 		{Name: "subquery_having_4", Age: 40}, | ||||
| 	} | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	var results []User | ||||
| 	DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB. | ||||
| 		Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results) | ||||
| 
 | ||||
| 	if len(results) != 2 { | ||||
| 		t.Errorf("Two user group should be found, instead found %d", len(results)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu