Allow Where select fields when searching with struct
This commit is contained in:
		
							parent
							
								
									79864af9ff
								
							
						
					
					
						commit
						1b8cb07cf2
					
				
							
								
								
									
										26
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								statement.go
									
									
									
									
									
								
							| @ -250,7 +250,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 
 | ||||
| 	conds := make([]clause.Expression, 0, 4) | ||||
| 	args = append([]interface{}{query}, args...) | ||||
| 	for _, arg := range args { | ||||
| 	for idx, arg := range args { | ||||
| 		if valuer, ok := arg.(driver.Valuer); ok { | ||||
| 			arg, _ = valuer.Value() | ||||
| 		} | ||||
| @ -310,11 +310,22 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 		default: | ||||
| 			reflectValue := reflect.Indirect(reflect.ValueOf(arg)) | ||||
| 			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { | ||||
| 				selectedColumns := map[string]bool{} | ||||
| 				if idx == 0 { | ||||
| 					for _, v := range args[1:] { | ||||
| 						if vs, ok := v.(string); ok { | ||||
| 							selectedColumns[vs] = true | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 				restricted := len(selectedColumns) != 0 | ||||
| 
 | ||||
| 				switch reflectValue.Kind() { | ||||
| 				case reflect.Struct: | ||||
| 					for _, field := range s.Fields { | ||||
| 						if field.Readable { | ||||
| 							if v, isZero := field.ValueOf(reflectValue); !isZero { | ||||
| 						selected := selectedColumns[field.DBName] || selectedColumns[field.Name] | ||||
| 						if selected || (!restricted && field.Readable) { | ||||
| 							if v, isZero := field.ValueOf(reflectValue); !isZero || selected { | ||||
| 								if field.DBName != "" { | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | ||||
| 								} else if field.DataType != "" { | ||||
| @ -326,8 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for i := 0; i < reflectValue.Len(); i++ { | ||||
| 						for _, field := range s.Fields { | ||||
| 							if field.Readable { | ||||
| 								if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { | ||||
| 							selected := selectedColumns[field.DBName] || selectedColumns[field.Name] | ||||
| 							if selected || (!restricted && field.Readable) { | ||||
| 								if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { | ||||
| 									if field.DBName != "" { | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | ||||
| 									} else if field.DataType != "" { | ||||
| @ -338,6 +350,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if restricted { | ||||
| 					break | ||||
| 				} | ||||
| 			} else if len(conds) == 0 { | ||||
| 				if len(args) == 1 { | ||||
| 					switch reflectValue.Kind() { | ||||
|  | ||||
| @ -921,6 +921,30 @@ func TestSearchWithMap(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSearchWithStruct(t *testing.T) { | ||||
| 	dryRunDB := DB.Session(&gorm.Session{DryRun: true}) | ||||
| 
 | ||||
| 	result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{}) | ||||
| 	if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { | ||||
| 		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) | ||||
| 	if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { | ||||
| 		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) | ||||
| 	if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { | ||||
| 		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{}) | ||||
| 	if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { | ||||
| 		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSubQuery(t *testing.T) { | ||||
| 	users := []User{ | ||||
| 		{Name: "subquery_1", Age: 10}, | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu